Pytorch实现CIFAR-10分类的5个步骤
tensorvision包中自带常用的视觉数据集,其中就包括CIFAR-10。Tutorial中将网络的训练分为了5个步骤:
- 准备数据:下载CIFAR-10并归一化
- 定义CNN
- 定义损失函数
- 在training set上训练CNN
- 在test set上测试CNN
1. 准备数据:下载CIFAR-10并归一化
import torch
import torchvision
import torchvision.transforms as transforms
transform = transforms.Compose(
[transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
shuffle=True, num_workers=2)
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4,
shuffle=False, num_workers=2)
classes = ('plane', 'car', 'bird', 'cat',
'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
数据处理使用torch.ultis.data
类与torchvision.transforms
类。
其中,transforms.Compose()
是将各种数据变换组合起来使用,由各种变换构成的列表。
torchvision.datasets.CIFAR10()
可以使用自己下载好的,解压后放在./data
内,并设置download=False
。
torch.utils.data.DataLoader() 类
class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, num_workers=0, collate_fn=<function default_collate>, pin_memory=False, drop_last=False)
torch.utils.data.DataLoader
类是数据加载器,它组合数据集和采样器,并在数据集上提供单进程或多进程迭代器。
显示部分图像
下面的代码用来显示部分图像:
import matplotlib.pyplot as plt
import numpy as np
# functions to show an image
def imshow(img):
img = img / 2 + 0.5 # unnormalize
npimg