图像分类数据集中最常用的是手写数字识别数据集MNIST。但大部分模型在MNIST上的分类精度都超过了95%。为了更直观地观察算法之间的差异,我们将使用一个图像内容更加复杂的数据集Fashion-MNIST(这个数据集也比较小,只有几十M,没有GPU的电脑也能吃得消)。
我们将使用torchvision包,它是服务于PyTorch深度学习框架的,主要用来构建计算机视觉模型。torchvision主要由以下几部分构成:
torchvision.datasets
: 一些加载数据的函数及常用的数据集接口;torchvision.models
: 包含常用的模型结构(含预训练模型),例如AlexNet、VGG、ResNet等;torchvision.transforms
: 常用的图片变换,例如裁剪、旋转等;torchvision.utils
: 其他的一些有用的方法。
Fashion-MNIST中一共包括了10个类别,分别为t-shirt(T恤)、trouser(裤子)、pullover(套衫)、dress(连衣裙)、coat(外套)、sandal(凉鞋)、shirt(衬衫)、sneaker(运动鞋)、bag(包)和ankle boot(短靴)。以下函数可以将数值标签转成相应的文本标签。
import torchimport torchvisionimport torchvision.transforms as transformsimport matplotlib.pyplot as pltimport timeimport sysfrom IPython import displaymnist_train=torchvision.datasets.FashionMNIST(root='./datasets/FashionMNIST',train=True,download=False,transform=transforms.ToTensor())mnist_test=torchvision.datasets.FashionMNIST(root='./datasets/FashionMNIST',train=False,download=False,transform=transforms.ToTensor())print(type(mnist_train))print(len(mnist_train),len(mnist_test))feature,label=mnist_train[0]print(feature.shape,label)def use_svg_display(): # 用矢量图显示 display.set_matplotlib_formats('svg')def set_figsize(figsize=(3.5, 2.5)): use_svg_display() # 设置图的尺寸 plt.rcParams['figure.figsize'] = figsizedef get_fashion_mnist_labels(labels): text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat', 'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot'] return [text_labels[int(i)] for i in labels]def show_fashion_mnist(images, labels): use_svg_display() _, figs = plt.subplots(1, len(images), figsize=(12, 12)) for f, img, lbl in zip(figs, images, labels): f.imshow(img.view((28, 28)).numpy()) f.set_title(lbl) f.axes.get_xaxis().set_visible(False) f.axes.get_yaxis().set_visible(False) plt.show()X, y = [], []for i in range(10): X.append(mnist_train[i][0]) y.append(mnist_train[i][1])show_fashion_mnist(X, get_fashion_mnist_labels(y))
分类结果: