一、关于F-MNIST数据集的一些操作
import matplotlib.pyplot as plt
import torch
import torchvision #pytorch对于计算机视觉模型实现库
from torch.utils import data
from torchvision import transforms #对数据进行操作的步骤
from d2l import torch as d2l
d2l.use_svg_display() #用svg显示图片,清晰度高
#通过框架中的内置函数将Fashion-MNIST数据集下载并读取到内存中
# 通过ToTensor实例将图像数据从PIL类型变换成32位浮点数格式,
# 并除以255使得所有像素的数值均在0~1之间
trans = transforms.ToTensor()#最简单的预处理
mnist_train = torchvision.datasets.FashionMNIST(
root="../data", train=True, transform=trans, download=True)
# 训练数据集true,数据类型是pytorch的tensor而不是图片
mnist_test = torchvision.datasets.FashionMNIST(
root="../data", train=False, transform=trans, download=True)
#print(len(mnist_train), len(mnist_test))看一下训练集和测试集多长
#mnist_train[0][0].shape看看训练集第一个训练集的图片(1是标号)的形状
def get_fashion_mnist_labels(labels): #@save
"""返回Fashion-MNIST数据集的文本标签"""
text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat',
'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
return [text_labels[int(i)] for i in labels]
#imgs:图像列表,可以是 PyTorch 张量或 PIL 图像
#num_rows 和 num_cols:决定了网格的行数和列数,即图像展示的布局
#titles(可选参数):图像标题列表,可以为每幅图像提供标题
#scale:缩放比例,控制图像的大小
def show_images(imgs, num_rows, num_cols, titles=None, scale=1.5): #@save
"""绘制图像列表"""
figsize = (num_cols * scale, num_rows * scale)
_, axes = d2l.plt.subplots(num_rows, num_cols, figsize=figsize)
#axes 是一个包含子图的数组,d2l.plt.subplots 函数会创建一个 num_rows 行,num_cols 列的网格,每个网格位置用于显示一张图像
#这里的-,表示占位符,因为subplots的输出有一个figure没有使用表明我们并不关心这个返回值
axes = axes.flatten()
#将二维的 axes 数组展开成一维数组,以便后面更方便地遍历处理每个子图
for i, (ax, img) in enumerate(zip(axes, imgs)):
#使用 enumerate 对 axes 和 imgs 进行同步遍历,每次分别获得一个轴 ax 和一张图像 img
if torch.is_tensor(img):
# 图片张量
ax.imshow(img.numpy())
#判断 img 是否是 PyTorch 张量(使用 torch.is_tensor 函数),
#如果是,则需要调用 img.numpy() 将其转化为 NumPy 数组以便 imshow 可以显示它
else:
# PIL图片
ax.imshow(img)
#如果 img 不是张量,则假定它是一个 PIL 图像,直接用 imshow 显示
ax.axes.get_xaxis().set_visible(False)
ax.axes.get_yaxis().set_visible(False)
#如果 img 不是张量,则假定它是一个 PIL 图像,直接用 imshow 显示
if titles:
ax.set_title(titles[i])
#如果有 titles 提供,则为每个图像设置对应的标题
return axes
#X, y = next(iter(data.DataLoader(mnist_train, batch_size=18)))看一看训练集前几个样本及其标签
#show_images(X.reshape(18, 28, 28), 2, 9, titles=get_fashion_mnist_labels(y))
#读取小批量
batch_size = 256
def get_dataloader_workers(): #@save
"""使用4个进程来读取数据"""
return 4
train_iter = data.DataLoader(mnist_train, batch_size, shuffle=True,
num_workers=get_dataloader_workers())
#看一下读取训练数据的所需的时间
timer = d2l.Timer()
for X, y in train_iter:
continue
f'{timer.stop():.2f} sec'
二、定义load_data_fashion_mnist函数
def load_data_fashion_mnist(batch_size, resize=None): #@save
"""下载Fashion-MNIST数据集,然后将其加载到内存中"""
trans = [transforms.ToTensor()]
if resize:
trans.insert(0, transforms.Resize(resize))
trans = transforms.Compose(trans)
mnist_train = torchvision.datasets.FashionMNIST(
root="../data", train=True, transform=trans, download=True)
mnist_test = torchvision.datasets.FashionMNIST(
root="../data", train=False, transform=trans, download=True)
return (data.DataLoader(mnist_train, batch_size, shuffle=True,
num_workers=get_dataloader_workers()),
data.DataLoader(mnist_test, batch_size, shuffle=False,
num_workers=get_dataloader_workers()))
#我们通过指定参数来测试函数的图像大小调整功能
rain_iter, test_iter = load_data_fashion_mnist(32, resize=64)
for X, y in train_iter:
print(X.shape, X.dtype, y.shape, y.dtype)
break