实现对CIFAR-10的分类,步骤如下:
- 使用torchvision加载并预处理CIFAR-10数据集
- 定义网络
- 定义损失函数和优化器
- 训练网络并更新网络参数
- 测试网络
1. CIFAR-10数据加载及预处理
import torchvision as tv
import torch as t
import torchvision.transforms as transforms
from torchvision.transforms import ToPILImage
show = ToPILImage() #将Tensor转成Image,方面可视化
# 第一次运行程序torchvision会自动下载CIFAR-10数据集。
# 如果已经下载有CIFAR-10,可通过root参数指定
# 定义对数据的预处理,Compose这个类是用来管理各个transform的
transform = transforms.Compose([transforms.ToTensor(), # 转为Tensor
transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))]) # 归一化
# 训练集
trainset = tv.datasets.CIFAR10(root='D:\\Workspace\\Python\\CIFAR-10\\', train = True, download=True, transform=transform)
trainloader = t.utils.data.DataLoader(trainset, batch_size=4, shuffle=True,num_workers=2)
#测试集
testset = tv.datasets.CIFAR10(root='D:\\Workspace\\Python\\CIFAR-10\\', train = False, download=True, transform=transform)
testloader = t.utils.data.DataLoader(testset, batch_size=4, shuffle=False,num_workers=2)
classes = ('plane','car','bird','cat','deer','dog','frog','horse','ship','truck')
Files already downloaded and verified
Files already downloaded and verified
1.1 ToTensor类是实现:Convert a PIL Image or numpy.ndarray to tensor的过程,在PyTorch中常用PIL库来读取图像数据,因此这个方法相当于搭建了PIL Image和Tensor的桥梁。另外要强调的是在做数据归一化之前必须要把PIL Image转成Tensor,而其他resize或crop操作则不需要。