- 参考:《动手学深度学习》(Pytorch)版 3.5 节
- 注:本文是 jupyter notebook 文档转换而来,部分代码可能无法直接复制运行!
- 图像分类数据集中最常用的是手写数字识别数据集MNIST,但大部分模型在MNIST上的分类精度都超过了95%,为了更直观地观察算法之间的差异,本文介绍一个图像内容更加复杂的数据集 Fashion-MNIST,这个数据集难度比 MNIST 高,但是尺寸并不大,只有几十M,没有GPU的电脑也能吃得消
- 该数据集可以利用
torchvision
包来下载和处理,该包包含以下几个核心模块torchvision.datasets
: 提供加载数据的函数及常用数据集接口;torchvision.models
: 包含常用的模型结构(含预训练模型),如 AlexNet、VGG、ResNet 等;torchvision.transforms
: 提供常用的图片变换方法,例如裁剪、旋转等;torchvision.utils
: 提供其他的一些有用的方法
- 开始介绍前,先导入包
import torch import torchvision import torchvision.transforms as transforms import matplotlib.pyplot as plt import time import numpy as np from IPython import display
1. 获取数据集
-
通过
torchvision.datasets.FashionMNIST
方法获取数据集mnist_train = torchvision.datasets.FashionMNIST(root='./Datasets/FashionMNIST', train=True, transform=transforms.ToTensor()) mnist_test = torchvision.datasets.FashionMNIST(root='./Datasets/FashionMNIST', train=False, transform=transforms.ToTensor())
参数说明
-
root
参数指定数据集保存路径 -
train
参数指定获取训练集还是测试集 -
download
参数若设置为True
,则在发现 root 路径下没有数据集时自动从网上下载,若已有数据集则不动作 -
transform = transforms.ToTensor()
使所有数据转换为Tensor
,如果不转换则返回的是 PIL 图片transforms.ToTensor()
将 “尺寸为 H × W × C H \times W \times C H×W×C 且数据位于 [ 0 , 255 ] [0, 255] [0,255] 的PIL图片” 或者 “数据类型为np.uint8
的NumPy数组” 转换为 “尺寸为 C × H × W C \times H \times W C×H×W 且数据类型为torch.float32
且位于[0.0, 1.0]
的Tensor”注意
transforms.ToTensor()
在内的一些关于图片的函数默认输入为uint8
类型,如果不是则可能得到不想要的结果,所以如果用 [ 0 , 255 ] [0,255] [0,255] 的像素值表示图片数据,则一律将其类型设置为uint8
,以免不必要的bug
-
-
这里加载的
mnist_train
和mnist_test
都是torch.utils.data.Dataset
的子类,一些常用方法如下print(type(mnist_train)) print(len(mnist_train), len(mnist_test)) # 用 len() 获取该数据集的大小 feature, label = mnist_train[0] # 通过下标来访问任意样本 print(feature.shape, label) # [Channel , Height , Width] label,注意由于数据集中都是灰度图,通道数为 1 ''' torchvision.datasets.mnist.FashionMNIST 60000 10000 torch.Size([1, 28, 28]) 9 '''
-
Fashion-MNIST中一共包括了10个类别,分别为
- t-shirt(T恤)
- trouser(裤子)
- pullover(套衫)
- dress(连衣裙)
- coat(外套)
- sandal(凉鞋)
- shirt(衬衫)
- sneaker(运动鞋)
- bag(包)
- ankle boot(短靴)
使用以下函数将数值标签列表转成相应的文本标签列表
def get_fashion_mnist_labels(labels): text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat', 'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot'] return [text_labels[int(i)] for i in labels]
-
使用以下函数在一行里绘制多个图像和对应的标签
def show_fashion_mnist(images, labels): display.set_matplotlib_formats('svg') # Use svg format to display plot in jupyter _, figs = plt.subplots(1, len(images), figsize=(12, 12)) for f, img, lbl in zip(figs, images, labels): f.imshow(img.view((28, 28)).numpy()) f.set_title(lbl) f.axes.get_xaxis().set_visible(False) f.axes.get_yaxis().set_visible(False) plt.show()
-
随机显示 10 个样本
X, y = [], [] for i in np.random.randint(0,60000,size = 10).tolist(): X.append(mnist_train[i][0]) y.append(mnist_train[i][1]) show_fashion_mnist(X, get_fashion_mnist_labels(y))
这里我遇到一个报错,请参考 ‘OMP: Hint This means that multiple copies of the OpenMP runtime have been linked into the program’,我删除了虚拟环境中的
libiomp5md.dll
解决此问题
2. 读取小批量
-
在实践中,数据读取经常是训练的性能瓶颈,
torch.utils
模块提供的DataLoader
方法允许我们方便地使用多进程来加速数据读取 -
mnist_train
是torch.utils.data.Dataset
的子类,所以我们可以将其传入torch.utils.data.DataLoader
来创建一个读取小批量数据样本的DataLoader
实例,在创建时- 通过参数
num_workers
来指定读取数据的进程数量 - 通过
shuffle
参数指定读取时是否打乱
batch_size = 256 if sys.platform.startswith('win'): # 判断操作系统为 windows num_workers = 4 # 使用 4 个进程同时读取 else: num_workers = 0 # 0表示不用额外的进程来加速读取数据 train_iter = torch.utils.data.DataLoader(mnist_train, batch_size=batch_size, shuffle=True, num_workers=num_workers) test_iter = torch.utils.data.DataLoader(mnist_test, batch_size=batch_size, shuffle=False, num_workers=num_workers)
- 通过参数
-
查看读取一遍数据的耗时
start = time.time() for X, y in train_iter: continue print('%.2f sec' % (time.time() - start))
经测试,我的笔记本电脑在不使用多进程加速时耗时 5.88s,使用后减少到 3.18s