前言
Fashion-MNIST由10个类别的图像组成,每个类别由训练集中的6000张图像和测试数据集中1000张图像组成。
定义图像转换方法
- 改变大小
- 转换成张量
trans=[transforms.Resize(64),transforms.ToTensor()]
trans=transforms.Compose(trans)
可选择是否对图片尺寸进行更改
trans=[transforms.ToTensor()]
if resize:
trans.insert(0,transforms.Resize(resize))
利用框架内置函数下载并读取数据集
mnist_train=torchvision.datasets.FashionMNIST(
root="../data",train=True,transform=trans,download=True
)
mnist_test=torchvision.datasets.FashionMNIST(
root="../data",train=False,transform=trans,download=True
)
数据格式为mnist_train[0][0]为图片,mnist_train[0][1]为标签
print(mnist_test[0][0].shape)
print(mnist_test[0][1])
torch.Size([1, 64, 64])
9
使用内置迭代器读取小批量
train_iter = data.DataLoader(mnist_train,batch_size,shuffle=True,num_workers=0)