冬日的早晨总是难起,所幸还剩下一两个小时,也不知道今天上午能否写完。
笔者昨日突遭torch和torchvision的版本兼容性背刺,又是检查环境又是查官方文档的推荐版本,结果啥也没查出来,最后清除了一下pytorch缓存再重启vscode却发现奇迹般地work了。笔者觉得这是一种幸运,也不知道下一次还会不会如此幸运,可能正是这些实践过程中的各种意外情况让笔者这个小白成长吧。
昨日笔者一边在还专业课内欠下的债,一边学习李沐老师的课程,两边的新东西一齐涌来,深感脑子要长出来了。而这也让笔者产生了一种吃力且无所适从之感。不过承受压力本身就是一种修行,笔者经过一个晚上的调整,也算是勉强把各种任务理顺了。按照既定的计划完成这些任务,反而能够静下心来,大有一种“何妨吟啸且徐行”的释然。
李沐老师是先简要介绍了Softmax回归,然后开始进行数据集等准备工作。然而笔者目前还不甚理解Softmax回归的原理,准备等学了Softmax回归的代码实现之后合成一篇学习记录。所以就先把准备工作写一写。然而就算是准备工作,对笔者来说亦已经不容易,李沐老师九分钟的时间行云流水般讲完,笔者却在面对代码时有种如坠云雾之感。这无疑是熟悉程度带来的差异。
言归正传,准备图像分类数据集。
导入工具
%matplotlib inline
import torch
import torchvision
from torch.utils import data
from torchvision import transforms
from d2l import torch as d2l
d2l.use_svg_display()
昨日笔者就是在这一步遭遇被刺,报错显示在 torch 模块中没有 _utils_internal 属性,让笔者十分抓马,找了很久都没发现问题所在,最后靠着清除Pytorch缓存再重启vscode解决问题。其余的没啥好说的,最后一行是设置在 notebook 中显示 SVG 格式的图形,这样可以使图形更加清晰。
读取数据集
# 通过ToTensor实例将图像数据从PIL类型变换成32位浮点数格式,
# 并除以255使得所有像素的数值均在0~1之间
trans = transforms.ToTensor()
mnist_train = torchvision.datasets.FashionMNIST(
root="../data", train=True, transform=trans, download=True)
mnist_test = torchvision.datasets.FashionMNIST(
root="../data", train=False, transform=trans, download=True)
这段代码把图片数据转换成张量储存,并设置好训练集和测试集。以下是GPT的逐行解释: PyTorch 的
-
transforms.ToTensor():这是一个图像转换的操作,将图像从 PIL(Python Imaging Library) 格式转换为 PyTorch 的 Tensor 类型。它还会将像素值的范围从 [0, 255] 缩放到 [0, 1]。
-
torchvision.datasets.FashionMNIST:这是 torchvision 中用于加载 FashionMNIST 数据集的类。它有一些参数:
-
root: 数据集存储的根目录。
-
train: 指定是否加载训练集,如果为 True,则加载训练集,否则加载测试集。
-
transform: 指定数据预处理的方法,这里使用了 transforms.ToTensor()。
-
download: 指定是否需要下载数据集。如果数据集不存在于 root 目录,会自动下载。
mnist_train 和 mnist_test:这是加载后的训练集和测试集对象。
综合起来,这段代码的作用是加载 FashionMNIST 数据集,将图像转换为 PyTorch 的 Tensor 类型,并进行一些预处理,使得数据集可以用于训练和测试深度学习模型。
查看训练集和测试集规模:
len(mnist_train), len(mnist_test)
查看图像基本信息(通道数,高度,宽度):
mnist_train[0][0].shape
其中第一个[]中的数字是选择训练集中的某个数据,笔者尝试了一下,修改第一个[]中数字的内容,得到的输出是一样的。但是修改第二个[]中数字时却遭遇了报错,这说明笔者还没有完全理解代码的含义,所以还是请GPT解释一番:
在 PyTorch 的 DataLoader 中,mnist_train[0] 返回的是一个包含图像和标签的元组 (image, label),其中 image 是图像数据,label 是对应的标签。
mnist_train[0][0] 获取的是这个元组中的图像部分,而 mnist_train[0][0].shape 是获取图像的形状。如果将第二个 [0] 的数字改为其他数字,就会访问元组中不存在的索引,因此会导致 IndexError。
如果你想访问不同索引的图像,可以尝试使用不同的索引值,确保索引值在数据集范围内。例如,mnist_train[1][0].shape 获取第二张图像的形状,mnist_train[2][0].shape 获取第三张图像的形状,以此类推。
那么问题就很明确了。label不存在shape方法,但是我们可以直接输出label看看:
mnist_train[0][1],mnist_train[0][0].shape
可以看到,label处储存的是一个整数。而接下来的这个函数就可以进行数字标签索引和文本名称的转换:
def get_fashion_mnist_labels(labels):
"""返回Fashion-MNIST数据集的文本标签"""
text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat',
'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
return [text_labels[int(i)] for i in labels]
注意到最后一行的类型转换int(i),笔者一开始武断地认为label都是整数,所以对操作不解,下面是GPT的解释,令人信服:
在这个函数中,labels 参数代表 Fashion-MNIST 数据集中的标签。尽管在 Fashion-MNIST 数据集中,标签的确是整数,但有时候标签可能以浮点数的形式存在,例如在进行一些数值计算或数据处理时,可能会得到浮点数形式的标签。为了确保在 text_labels[int(i)] 中的索引操作是整数,将 i 明确地转换为整数是一个良好的做法。
这种做法可以避免因为标签是浮点数而导致的潜在问题,确保索引操作总是使用整数值。在这里,虽然你可能看到 i 确实是整数,但为了提高函数的健壮性,保留了对标签进行整数转换的步骤。如果你确保输入的标签总是整数,你也可以省略这个转换。
这个函数可以将图像样本可视化:
def show_images(imgs, num_rows, num_cols, titles=None, scale=1.5):
"""绘制图像列表"""
figsize = (num_cols * scale, num_rows * scale)
_, axes = d2l.plt.subplots(num_rows, num_cols, figsize=figsize)
axes = axes.flatten()
for i, (ax, img) in enumerate(zip(axes, imgs)):
if torch.is_tensor(img):
# 图片张量
ax.imshow(img.numpy())
else:
# PIL图片
ax.imshow(img)
ax.axes.get_xaxis().set_visible(False)
ax.axes.get_yaxis().set_visible(False)
if titles:
ax.set_title(titles[i])
return axes
这个函数初见时令笔者抓瞎,而GPT一开始给笔者的解释居然也是错误的,还好笔者又一次幸运地解决了问题。还是先看GPT的逐行解释(这个版本应该是正确的):
-
figsize = (num_cols * scale, num_rows * scale): 计算图形的大小,基于每行和每列的图像数量以及缩放比例。
-
_, axes = d2l.plt.subplots(num_rows, num_cols, figsize=figsize): 使用 Matplotlib 的 subplots 函数创建一个包含 num_rows 行、num_cols 列的子图图形,并获取返回的子图对象数组 axes。下划线 _ 表示我们不关心第一个返回值(一个 Figure 对象),而 axes 是一个包含所有子图轴对象的数组。
-
axes = axes.flatten(): 将二维数组的 axes 展平为一维数组,这样我们就可以通过一维索引来访问每个轴对象。
-
for i, (ax, img) in enumerate(zip(axes, imgs)):: 使用 enumerate 函数迭代图像列表和轴数组,获取每个图像和相应的轴对象。
-
if torch.is_tensor(img):: 检查图像是否为 PyTorch 张量
-
如果是张量,使用 ax.imshow(img.numpy()) 将张量转换为 NumPy 数组并显示在轴对象上。
-
如果不是张量,假设它是 PIL 图像对象,使用 ax.imshow(img) 直接在轴对象上显示 PIL 图像。
6.ax.axes.get_xaxis().set_visible(False) 和 ax.axes.get_yaxis().set_visible(False): 隐藏每个轴对象的 x 轴和 y 轴。
7.if titles: ax.set_title(titles[i]): 如果提供了标题列表 titles,则为每个轴对象设置相应的标题。
8.最后,函数返回轴对象数组 axes
笔者一开始是从'_,'入手的,因为没怎么学过python,所以好奇这种符号的用法。后面发现是因为d2l.plt.subplots函数返回两个值,而我们只需要用第二个,所以就使用了占位符'_,'。接下来看到对axes进行展平为一维数组的操作,自然关心axes本身是什么结构才引起这种操作。然后问了下GPT知道了轴对象的概念:
在 Matplotlib 中,子图(subplots)和轴(axes)是两个相关但不同的概念。
-
子图(Subplots):
-
子图是整个图形中的一个区域,可以包含一个或多个轴。子图是图形的最外层容器,可以包含多个轴对象。
-
plt.subplots() 函数用于创建包含多个子图的图形。
-
轴(Axes):
-
轴是子图中的一个独立绘图区域,包含了坐标轴、图表元素等。一个子图可以包含一个或多个轴。
-
在 Matplotlib 中,轴对象是实际进行绘图的主要对象,它包含了很多方法用于绘制图形。
在代码中,plt.subplots() 返回一个包含所有子图轴对象的数组,你可以使用数组的索引来访问特定的轴对象。例如,如果你使用 axes[0, 0],你将访问第一行第一列的轴对象。
接下来简单用代码测试一下,axes的确是一个2维数组:
import numpy as np
# 创建一个包含2行3列子图的画布
fig, axes = d2l.plt.subplots(2, 3)
# 打印类型和形状
print(f'Type of axes: {type(axes)}')
print(f'Shape of axes: {np.shape(axes)}')
现在可以发现这些操作的确是合理的。接下来令笔者感兴趣的是enumerate函数的用法:
当你在Python中使用enumerate函数时,它允许你同时遍历列表(或其他可迭代对象)的元素和它们的索引。在这个循环中: for i, (ax, img) in enumerate(zip(axes, imgs)):
-
enumerate(zip(axes, imgs)):这部分是使用zip将axes和imgs打包在一起,创建一个迭代器,其中每个元素都是(ax, img)的元组。enumerate函数用于在迭代过程中获取元素的索引和值。
-
i:这是索引,表示当前迭代的位置。
-
(ax, img):这是元组解包,表示当前迭代的值,其中ax是轴对象,img是图像。
这样,通过enumerate,你可以在每次迭代中同时获得轴对象和图像,而不仅仅是图像或索引。这对于在循环体内根据索引操作多个变量或对象很有用。在这个例子中,它使得可以方便地将图像显示在相应的子图中,并进行其他自定义操作。
现在简单看看前几个样本:
X, y = next(iter(data.DataLoader(mnist_train, batch_size=18)))
show_images(X.reshape(18, 28, 28), 2, 9, titles=get_fashion_mnist_labels(y));
第一行的用法已经在上一节接触过了,这里就是用迭代器并返回第一个批次。令笔者感兴趣的是第二行最后加的分号,因为笔者刚刚学习python时,花了一段时间才适应了不加分号,所以对分号有点敏感,结果就把分号去掉,发现多输出了这么一段:
以下是GPT的解释:
1.在交互式环境(如Jupyter Notebook)中,每个代码单元的最后一行的结果会被自动打印出来,除非这个结果被分号(;)结尾。这是因为分号在Python中被用作语句结束符,它告诉解释器不要打印这个语句的结果。
如果你在代码的最后使用分号,就相当于告诉解释器“不要输出这个语句的结果”。这对于一些场景很有用,例如你执行了一个函数或语句,但不希望在屏幕上看到输出。
如果你没有使用分号,解释器会默认打印出最后一行语句的结果。这对于交互性很有帮助,因为你可以立即看到代码执行的结果。
所以,添加或去掉分号主要是为了控制输出是否显示在屏幕上,而不影响代码本身的执行。
2.在 Jupyter Notebook 中,当你执行一个代码单元时,它会显示代码单元最后一行的输出。如果这个输出是一个对象,Jupyter Notebook 会调用该对象的 _repr_html_ 方法来获取 HTML 表示,并将其显示在代码单元的输出区域。
在你的代码中,当你使用分号 ; 时,虽然抑制了默认的输出,但代码仍然在执行,show_images 函数仍然返回了一个包含子图对象的数组 axes。由于有分号,Jupyter Notebook 不会显示这个返回值,而是调用 axes 数组中每个子图对象的 _repr_html_ 方法,将它们的 HTML 表示堆叠在一起显示出来。
因此,分号抑制了默认输出,但并没有阻止 Jupyter Notebook 调用 _repr_html_ 方法来渲染子图对象的 HTML 表示。
这就解释了为什么去掉分号会多出这么一段,根据前面对函数的定义,这其实是打印了axes本身的信息。
讲点题外话,今天由于中途经历点事情,再回来开始写已经是下午快四点了。结果发现原本自习的课室开始上课了,在讲随机过程的随机数生成,笔者又感觉脑子快长出来了。浅浅地希望自己明年面对这些问题就能够不再抓瞎吧。
读取小批量
batch_size = 256
def get_dataloader_workers():
"""使用4个进程来读取数据"""
return 4
train_iter = data.DataLoader(mnist_train, batch_size, shuffle=True,
num_workers=get_dataloader_workers())
唯一可能引起疑惑的是用4个进程来读取数据,看看GPT解释:
在代码中,num_workers 参数是用来指定数据加载器(DataLoader)使用的子进程数量。具体地说,它表示用于数据加载的并行进程数。在你的代码中,num_workers=get_dataloader_workers() 使用了 get_dataloader_workers 函数来获取并行加载数据时的进程数,该函数返回 4。
通过使用多个进程加载数据,你可以加速数据的加载过程,特别是当数据集很大时。每个进程可以负责加载不同的数据批次,从而提高数据读取的效率,同时减轻主进程的负担。
接下来查看读取数据的时间:
timer = d2l.Timer()
for X, y in train_iter:
continue
f'{timer.stop():.2f} sec'
这段代码使用了 d2l.Timer() 来计算一个代码块的执行时间。具体步骤如下:
-
timer = d2l.Timer() 创建了一个计时器对象。
-
for X, y in train_iter: 是一个循环,用于遍历 train_iter 数据加载器的所有批次。每个批次都包含一组输入数据 X 和相应的标签 y。
-
continue 关键字表示在当前循环中什么都不做,直接进入下一次迭代。
-
f'{timer.stop():.2f} sec' 使用了计时器的 stop() 方法,该方法返回从计时器开始到当前时刻的时间间隔,并将结果格式化为字符串。:.2f 表示保留两位小数。这个字符串表示执行整个循环所花费的时间。
总的来说,这段代码测量了 train_iter 数据加载器迭代一次(即遍历所有批次)所需的时间,并将结果以字符串形式打印出来。
这就有意思了。笔者更换了上一个代码块内的进程数,然后比较了不同进程数下读取数据所需要的时间,发现时间随着进程数增加具有先减后增的趋势,并且在进程等于4时大约达到最小值(甚至改变batch_size大小,4也仍然是最佳的),说明4这个数字果然是提前设计好的。然后就询问了一下GPT给出了一些解释: 在使用多进程读取数据时,理论上增加进程数可以加速数据加载。然而,并不是进程数越多越好,因为多进程之间需要进行通信和协同工作,而这也会带来一些开销。
通常情况下,增加进程数可以在以下情况下提高数据加载效率:
-
I/O密集型任务: 当数据加载的主要瓶颈是I/O(例如从磁盘读取数据)时,多进程可以使CPU在等待数据加载的同时执行其他任务,从而提高效率。
-
多核CPU: 如果机器有多个物理核心,多进程可以并行地利用这些核心,从而提高整体性能。
-
数据加载速度受限: 当数据加载速度受限时,例如在网络数据加载中,增加进程数可以充分利用网络带宽。
然而,进程数增加也伴随着一些问题:
-
进程启动和销毁开销: 每个进程都需要一些时间来启动和销毁,这可能会在小规模的数据加载任务中成为性能瓶颈。
-
进程间通信开销: 进程之间需要进行通信,这会引入一些开销。在某些情况下,这个开销可能会抵消多进程带来的性能提升。
-
资源竞争: 多个进程同时访问共享资源(如内存)可能导致资源竞争,需要进行同步操作,这也会引入一些开销。
因此,选择合适的进程数通常需要根据具体的任务和硬件环境进行调整和优化。在你的实验中,进程数为4时获得最佳性能可能是因为它在利用多核CPU的同时避免了过多的进程启动和销毁开销。
整合所有组件
def load_data_fashion_mnist(batch_size, resize=None):
"""下载Fashion-MNIST数据集,然后将其加载到内存中"""
trans = [transforms.ToTensor()]
if resize:
trans.insert(0, transforms.Resize(resize))
trans = transforms.Compose(trans)
mnist_train = torchvision.datasets.FashionMNIST(
root="../data", train=True, transform=trans, download=True)
mnist_test = torchvision.datasets.FashionMNIST(
root="../data", train=False, transform=trans, download=True)
return (data.DataLoader(mnist_train, batch_size, shuffle=True,
num_workers=get_dataloader_workers()),
data.DataLoader(mnist_test, batch_size, shuffle=False,
num_workers=get_dataloader_workers()))
这段代码定义了一个函数 load_data_fashion_mnist,该函数用于加载Fashion-MNIST数据集并创建数据加载器:
-
函数签名: def load_data_fashion_mnist(batch_size, resize=None): 这个函数接受两个参数:batch_size表示每个批次的样本数,resize表示是否对图像进行调整大小。resize参数默认为None,表示不进行调整大小。
-
数据转换: trans = [transforms.ToTensor()] if resize: trans.insert(0, transforms.Resize(resize)) trans = transforms.Compose(trans) 这里创建了一个数据转换列表 trans。如果 resize 不为 None,则在转换列表的最前面插入一个 transforms.Resize(resize) 转换。然后,通过 transforms.Compose(trans) 创建一个组合转换,这样可以按照列表中的顺序应用这些转换。
-
加载训练和测试数据集: mnist_train = torchvision.datasets.FashionMNIST( root="../data", train=True, transform=trans, download=True) mnist_test = torchvision.datasets.FashionMNIST( root="../data", train=False, transform=trans, download=True) 使用 torchvision.datasets.FashionMNIST 下载并加载Fashion-MNIST数据集。传递 root 参数指定数据集下载的路径,train 参数表示加载训练集还是测试集,transform 参数指定数据转换。
-
创建数据加载器: return (data.DataLoader(mnist_train, batch_size, shuffle=True, num_workers=get_dataloader_workers()), data.DataLoader(mnist_test, batch_size, shuffle=False, num_workers=get_dataloader_workers())) 使用 data.DataLoader 创建训练集和测试集的数据加载器。batch_size 指定每个批次的样本数,shuffle 参数表示是否对数据进行随机洗牌,num_workers 参数表示用于数据加载的进程数,使用 get_dataloader_workers() 函数获取最佳进程数。
这个函数的目的是为了方便地加载Fashion-MNIST数据集,并返回相应的训练集和测试集的数据加载器。按照书上的讲法,这相当于构建了数据迭代器从而可以分批次处理数据。
指定resize参数测试图像大小调整功能:
train_iter, test_iter = load_data_fashion_mnist(32, resize=64)
for X, y in train_iter:
print(X.shape, X.dtype, y.shape, y.dtype)
break
练习
1. 减少`batch_size`(如减少到1)是否会影响读取性能? 2. 数据迭代器的性能非常重要。当前的实现足够快吗?探索各种选择来改进它。 3. 查阅框架的在线API文档。还有哪些其他数据集可用?
第二题相对有意思,看看GPT怎么说:
在代码中,数据迭代器的性能受到 num_workers 参数的影响,该参数表示用于数据加载的进程数。通过使用多进程加载数据,可以在数据加载的同时进行其他处理,提高效率。然而,过多的进程可能导致资源竞争和性能下降。
为了改进数据迭代器的性能,可以尝试以下几种选择:
-
增加或减少 num_workers: 调整 num_workers 参数,增加或减少进程数,观察对性能的影响。通常,num_workers 设置为计算机的核心数是一个不错的起点,但具体的最佳值可能因计算环境而异。 pythonCopy code def get_dataloader_workers(): """根据计算机配置返回合适的进程数""" # 返回适当的进程数,可以根据实际情况进行调整 return 4
-
使用更高级的数据加载器: 考虑使用 torch.utils.data.DataLoader 的不同实现,例如 torch.utils.data.DataLoader 中的 WorkerManager 或 torch.utils.data.Dataloader 的 DataLoaderIter。不同实现可能对性能有一定的影响。
-
调整批次大小: 尝试调整批次大小,观察对性能的影响。较大的批次大小可能能够更有效地利用计算资源。
-
使用 GPU: 如果计算机支持 GPU,并且数据集较大,考虑将数据加载到 GPU 上进行处理。这可以通过将模型和数据移动到 GPU 上来实现。
-
异步数据加载: 使用异步数据加载库,例如 dataloader,可以进一步提高数据加载效率。
兜兜转转终于写完了,不禁令笔者对日后的更新感到一丝艰难,不过或许是刚刚开始,打好基础,或许以后效率可以越来越高。如有任何错漏之处,还望读者不吝赐教,感激不尽。