图像分类数据集
MNIST数据集是图像分类中广泛使用的数据集之一,是对手写数字的识别,大概86年提出的
但作为基准数据集过于简单。
我们将使用类似但更复杂的Fashion-MNIST数据集
%matplotlib inline
import torch
import torchvision #是pytorch对于计算机视觉实现的一个库
from torch.utils import data#方便读取数据一些小批量的函数
from torchvision import transforms#对数据操作的模具导入进来
from d2l import torch as d2l#将一些函数实现好之后存在d2l里面
d2l.use_svg_display()#用svg来显示我们的图片,这样子清晰度高一点
读取数据集
可以通过框架中的内置函数将Fashion-MNIST数据集下载并读取到内存中。
# 通过ToTensor实例将图像数据从PIL类型变换成32位浮点数格式,
# 并除以255使得所有像素的数值均在0到1之间
trans = transforms.ToTensor()
mnist_train = torchvision.datasets.FashionMNIST(#从torchvision.datasets里面把FashionMNIST拿到
root="../data", train=True, transform=trans, download=True)#"../data"下载到上级目录的data下面,train=True下载的是训练数据集,transform=trans是说我们拿出来之后我们需要得到的是一个pytorch的tensor而不是一堆图片,download=True意思是我们默认从网上下载
mnist_test = torchvision.datasets.FashionMNIST(
root="../data", train=False, transform=trans, download=True)#train=False下载的是测试集
Fashion-MNIST由10个类别的图像组成, 每个类别由训练数据集(train dataset)中的6000张图像 和测试数据集(test dataset)中的1000张图像组成。
因此,训练集和测试集分别包含60000和10000张图像。 测试数据集不会用于训练,只用于评估模型性能。
len(mnist_train), len(mnist_test)(60000, 10000)
(60000, 10000)
每个输入图像的高度和宽度均为28像素。 数据集由灰度图像组成,其通道数为1。
将高度ℎ像素、宽度𝑤像素图像的形状记为ℎ×𝑤或(ℎ,w)。
mnist_train[0]