PyTorch 教程系列:https://blog.csdn.net/qq_38962621/category_10652223.html
PyTorch教程-4:PyTorch中网络的训练与测试
基本原理
对于要训练的模型,首先我们需要定义其结构,实例化一个用于计算Loss的loss_function
和一个用于更新参数的optimizer
,之后的事情就比较简单了,只要准备训练数据,然后设定训练的代数(或者停止条件)就可以进行迭代的训练。最后保存模型。对于验证的模型,只要将数据传输进训练好的模型中就能得到预测的结果,当然这个过程通常是不需要计算梯度的,所以通常在 torch.no_grad()
的条件下进行。
本节以CIFAR10的一个简单训练为例进行说明
准备数据
PyTorch中提供了很好的数据接口,当然也可以配合其他的数据集与数据加载工具。对于本节中的图像类型的数据,PyTorch的 torchvision
模块提供了很好的帮助,以及 torchvision.transforms
则对图像的预处理与变换提供了很方便的方法。这些在后边的内容中会详细地介绍,这里只给出一个简单的加载数据的例子,并将CIFR10的数据分成训练集和测试集:
import torch
import torchvision
import torchvision.transforms as transforms
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4, shuffle=True, num_workers=2)
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4, shuffle=False</