文章目录
训练分类器
到目前为止,我们已经了解了如何定义NN,计算loss以及更新网络的权重。那我们接下来再来看一下数据的处理
data
当我们需要处理图像、文本、或音/视频文件时,通常可以找到一些python包来载入数据到numpy数组中。然后我们可以将这个数组再转化为torch.*Tensor
。
- 对图片来说,可以用
Pillow
,OpenCV
- 对音频来说,可以用
scipy
,librosa
- 对文本来说,可以用python或Cython自带的载入,或者是
NLTK
和SpaCy
特别地,对于视觉方面,可以使用torchvision
,其中包括了对Imagenet、CIFAR10、MNIST等常用数据集的数据加载器(data loaders),包括对图片数据变形的操作。即torchvision.datasets
和torch.utils.data.DataLoader
。
在这个教程中,我们将使用CIFAR10数据集。它有如下的分类:airplane、automobile、bird、 cat、 deer、 dog、 frog、horse、 ship、 truck 等。在CIFAR-10里面的图片数据大小是3x32x32,即三通道彩色图,图片大小是32x32像素。
training an image classifier
接下来会逐步做如下操作,其实也就是训练分类器的步骤:
- 通过
torchvision
加载CIFAR10里面的训练和测试数据集,并对数据进行标准化处理 - 定义卷积神经网络
- 定义损失函数
- 利用训练数据训练网络
- 利用测试数据测试网络
1. Loading and normalizing CIFAR10
使用torchvision
可以很方便地加载CIFAR10
import torch
import torchvision
import torchvision.transforms as transforms
torchvision数据集加载完的输出为范围在[0,1]之间的PILImage图片。我们需要将其标准化为范围为[-1,1]之间的张量。下面的Normalize
,param第一项是mean序列,此处为3项,因为有三个channel。第二项是std序列,同样有3项。同时注意此处的转换是out-of-place,它不会改变原有输出。
transform = transforms.Compose(
[transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
shuffle=True, num_workers=2)
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4,
shuffle=False, num_workers=2)
classes = ('plane', 'car', 'bird', 'cat',
'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
输出:
Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10