本文参考PyTorch官网的教程,分为五个基本模块来介绍PyTorch。为了避免文章过长,这五个模块分别在五篇博文中介绍。
Part1:PyTorch简单知识
Part2:PyTorch的自动梯度计算
Part3:使用PyTorch构建一个神经网络
Part4:训练一个神经网络分类器
Part5:数据并行化
本文是关于Part4的内容。
Part4:训练一个神经网络分类器
前面已经介绍了定义神经网络,计算损失和更新权重,这里介绍训练神经网络分类器。
1关于数据
通常,当你需要处理图像、文本、饮品或者视频数据,你可以使用标准的python包将数据导入到numpy的array中。之后,你可以将array转换到torch.*Tensor。
(1)对于图像,Pillow、OpenCV等包非常有用。
(2)对于音频,scipy和librosa等包非常好。
(3)对于文本,原始Python或基于Cython的加载,或者NLTK和SpaCy都是有用的。
尤其对于视觉,我们创建了一个叫做torchvision的包,包含了对于常用数据集(如ImageNet,CIFAR10,MNIST等)的数据加载器和对于images、viz的数据转换器,torchvision.datasets和torch.utils.data.DataLoader。
在该教程中,我们使用CIFAR10数据集。它含有这些类:‘airplane’,‘automobile’,‘bird’,‘cat’,‘deer’,‘dog’,‘frog’,‘horse’,‘ship’,‘truck’。这些图像的尺寸是3*32*32,即3通道的彩色图像,尺寸为32*32。