准备所需导入的库
代码:
import torch
import torchvision
from torch.utils import data
from torchvision import transforms
from d2l import torch as d2l
# 将 d2l 库生成的图形的显示格式设置为 SVG(可缩放矢量图形)格式
d2l.use_svg_display()
一、读取多类分类图片数据集
Fashion-MNIST数据集
通过内置函数将Fashion-MNIST数据集下载并读取到内存中
## 使用框架内置函数将Fashion-MNIST数据集下载并读取到内存中
# 通过ToTensor实例将图像数据从PIL类型变换为32位浮点数
# 除以255使得所有像素的值在0到1之间
trans = transforms.ToTensor()
# 定义训练数据集和测试数据集,从网上下载
mnist_train = torchvision.datasets.FashionMNIST(
root="../data", train=True, transform=trans, download=True
)
mnist_test = torchvision.datasets.FashionMNIST(
root="../data", train=False, transform=trans, download=True
)
# 显示数据集的数据量(图片张数)
print(len(mnist_train))
print(len(mnist_test))
· Fashion-MNIST由10个类别的图像组成, 每个类别由训练数据集(train dataset)中的6000张图像 和测试数据集(test dataset)中的1000张图像组成。 因此,训练集和测试集分别包含60000和10000张图像。 测试数据集不会用于训练,只用于评估模型性能。
二、可视化数据样本
将图片数据集下载下来后,要对数据集进行可视化,将图片显示出来。
Fashion-MNIST中包含的10个类别,分别为t-shirt(T恤)、trouser(裤子)、pullover(套衫)、dress(连衣裙)、coat(外套)、sandal(凉鞋)、shirt(衬衫)、sneaker(运动鞋)、bag(包)和ankle boot(短靴)。
1. 定义一个函数,用于在数字标签索引及其文本名称之间进行转换。
def get_fashion_mnist_labels(labels): #@save
"""返回Fashion-MNIST数据集的文本标签"""
text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat',
'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
return [text_labels[int(i)] for i in labels]
2. 创建一个可视化函数,将PIL图片连同标签显示出来
def show_images(imgs, num_rows, num_cols, titles=None, scale=1.5): #@save
"""绘制图像列表"""
figsize = (num_cols * scale, num_rows * scale)
_, axes = d2l.plt.subplots(num_rows, num_cols, figsize=figsize)
axes = axes.flatten()
for i, (ax, img) in enumerate(zip(axes, imgs)):
if torch.is_tensor(img):
# 图片张量
ax.imshow(img.numpy())
else:
# PIL图片
ax.imshow(img)
ax.axes.get_xaxis().set_visible(False)
ax.axes.get_yaxis().set_visible(False)
if titles:
ax.set_title(titles[i])
return axes
X, y = next(iter(data.DataLoader(mnist_train, batch_size=18)))
show_images(X.reshape(18, 28, 28), 2, 9, titles=get_fashion_mnist_labels(y));
d2l.plt.show()
数据集中前几个样本的图像和对应标签:
三、读取小批量数据
为了使我们在读取训练集和测试集时更容易,我们使用内置的数据迭代器,而不是从零开始创建。 回顾一下,在每次迭代中,数据加载器每次都会读取一小批量数据,大小为batch_size
。 通过内置数据迭代器,我们可以随机打乱所有样本,从而无偏见地读取小批量。
使用4进程读取数据,也可以使用8进程等(根据CPU选择)。
代码:
batch_size = 256
def get_dataloader_workers(): #@save
"""使用4个进程来读取数据"""
return 4
train_iter = data.DataLoader(mnist_train, batch_size, shuffle=True,
num_workers=get_dataloader_workers())
· DataLoader表示一个数据加载器对象,用于批量加载数据。它接收训练数据集、批次大小、是否打乱数据以及读取数据的进程数量等参数。
· 在这里,mnist_train 是要加载的训练数据集,batch_size 是每个批次中的样本数量,shuffle=True 表示在每个 epoch(训练轮次)开始时打乱数据集,然后使用4个进程来读取数据。
· 创建数据加载器对象,可以方便地迭代访问训练数据集中的批次数据,以供模型训练使用。
查看读取训练数据所需的时间:
timer = d2l.Timer()
for X, y in train_iter:
continue
f'{timer.stop():.2f} sec'
· 使用 for 循环遍历 train_iter,每次迭代获取一个批次的训练数据,其中 X 是输入数据,y 是对应的标签数据。在这个例子中,由于循环体内没有实际的操作,使用 continue 关键字跳过当前迭代,即不执行任何操作。timer则用于计时。
· 接下来,使用 timer.stop() 方法停止计时,并使用 f-string 格式化字符串的方式将计时结果转换为保留两位小数的字符串。
四、整合图片数据集组件
定义load_data_fashion_mnist
函数,用于获取和读取Fashion-MNIST数据集。 这个函数返回训练集和验证集的数据迭代器。 此外,这个函数还接受一个可选参数resize
,用来将图像大小调整为另一种形状。
# 整合所有组件,创建图像分类数据集
def load_data_fashion_mnist(batch_size, resize=None): #@save
"""下载Fashion-MNIST数据集,然后将其加载到内存中"""
trans = [transforms.ToTensor()]
# 如果resize参数存在,则将 transforms.Resize(resize) 实例插入到 trans 列表的第一个位置
# 这样,可以确保在将图像数据转换为张量之前,先对图像进行调整大小的操作。
if resize:
trans.insert(0, transforms.Resize(resize))
# 使用 transforms.Compose 函数将 trans 列表中的实例组合成一个变换序列——方便一次性对图像数据进行多个变换操作
trans = transforms.Compose(trans)
# 下载数据集
mnist_train = torchvision.datasets.FashionMNIST(
root="../data", train=True, transform=trans, download=True)
mnist_test = torchvision.datasets.FashionMNIST(
root="../data", train=False, transform=trans, download=True)
# 返回两个数据加载器对象,第一个是训练数据,第二个是测试数据
return (data.DataLoader(mnist_train, batch_size, shuffle=True,
num_workers=get_dataloader_workers()),
data.DataLoader(mnist_test, batch_size, shuffle=False,
num_workers=get_dataloader_workers()))
# 获取数据加载器对象,并指定图像大小为64x64
train_iter, test_iter = load_data_fashion_mnist(32, resize=64)
for X, y in train_iter:
print(X.shape, X.dtype, y.shape, y.dtype)
break
通过该函数,我们可以很方便地得到图像分类数据集,获取训练数据和测试数据的数据加载器,以供后续模型训练使用。