pytorch学习(3)训练一个图片分类器
与上一章不同,这次我们基于CIFAR10数据集训练一个完整的图片分类器,一般有如下步骤:
用torchvision导入CIFAR10数据集并初始化
torchvision是计算机视觉中很重要的软件包,其封装了大量的数据集、预训练网络等,用起来十分方便:
下载CIFAR10数据集:
import torch
import torchvision
import torchvision.transforms as transforms
transform = transforms.Compose(
[transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
shuffle=True, num_workers=2)
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
download=True, transform