softmax回归的从零开始实现:
%matplotlib inline
import torch
import torchvision #pytorch用于计算机视觉的一个库
from torch.utils import data
from torchvision import transforms #导入对数据操作的模具
from d2l import torch as d2l
# --------1--------
# 获取数据集的标签
def get_fashion_mnist_labels(labels): # @save
"""返回Fashion-MNIST数据集的文本标签"""
text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat',
'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
return [text_labels[int(i)] for i in labels]
# 通过ToTensor实例将图像数据从PIL类型变换成32位浮点数格式,
# 并除以255使得所有像素的数值均在0~1之间
trans = transforms.ToTensor() #对图片进行预处理,转换为tensor格式
# 下载训练集和测试集,并保存
mnist_train = torchvision.datasets.FashionMNIST(
root="../data", train=True, transform=trans,download=True)
mnist_train = torchvision.datasets.FashionMNIST(
root="../data", train=False, transform=trans,download=True)
# 输出训练集和测试集的大小
len(mnist_train), len(mnist_test)
# 索引到第一张图片
mnist_train[0][0].shape # 输入图像的通道数、高度和宽度
# --------2--------
# 获取数据集的标签
def get_fashion_mnist_labels(labels): #@save
"""返回Fashion-MNIST数据集的文本标签"""
text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat',
'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
return [text_lables[int(i)] for i in labels]
# --------3--------
def show_images(imgs, num_rows, num_cols, titles=None, scale=1.5): #@save
"""绘制图像列表"""
figsize = (num_cols * scale, num_rows * scale)
# 第1个参数是个图,一般不用;第2个axer类似于图片的索引矩阵(行,列)
_, axes = d2l.plt.subplots(num_rows, num_cols, figsize=figsize) # axes:轴
axes = axes.flatten()
# 遍历生成形如i, (ax, img)形式的enumerate对象
for i, (ax, img) in enumerate(zip(axes, imgs)):
if torch.is_tensor(img):
# 图片张量
ax.imshow(img.numpy())
else:
# PIL图片
ax.imshow(img)
ax.axes.get_xaxis().set_visible(False) #x轴隐藏
ax.axes.get_yaxis().set_visible(False) #y轴隐藏
if titles:
ax.set_title(title[i]) #显示标题
return axes
# ------4--------
# 绘制图表
# 使用next()函数获取批量大小为18的训练集的图像和标签
X, y = next(iter(data.DataLoader(mnist_train, batch_size=18)))
#显示18张图片,宽度为28,长度为28,总共为2行9列
# 绘制两行图片,每一行有9张图片,并获取标签
show_images(X.reshape(18, 28, 28), 2, 9, titles=get_fashion_mnist_labels(y));
# pylab.show() 查看图表
# ------5--------
# 读取小批量
batch_size = 256
def get_dataloader_workers(): #@save
"""使用4个进程来读取数据"""
return 4
# 训练集需要设置shuffle=True打乱顺序
train_iter = data.DataLoader(mnist_train, batch_size, shuffle=True,
num_workers=get_dataloader_workers())
# ------6--------
# 查看读取时间
timer = d2l.Timer() #调用Timer函数,测试速度
for X, y in train_iter:
continue
f'{timer.stop():.2f} sec' #输出读取数据所用的秒数,精度为2位小数
# ------7--------
# 整合所有组件
def load_data_fashion_mnist(batch_size, resize=None): # @save
"""下载Fashion-MNIST数据集,然后将其加载到内存中"""
# 转换为tensor
trans = [transforms.ToTensor()]
if resize:
trans.insert(0, transforms.Resize(resize)) # 0在此处是索引
# compose整合步骤
trans = transforms.Compose(trans)
# 下载训练集和测试集,将小批量样本返回到train_iter中,用于之后的训练
mnist_train = torchvision.datasets.FashionMNIST(
root="../data", train=True, transform=trans, download=True)
mnist_test = torchvision.datasets.FashionMNIST(
root="../data", train=False, transform=trans, download=True)
return (data.DataLoader(mnist_train, batch_size, shuffle=True,
num_workers=get_dataloader_workers()),
data.DataLoader(mnist_test, batch_size, shuffle=False,
num_workers=get_dataloader_workers()))
# ------8--------
# 通过resize参数来测试load_data_fashion_mnist函数的图像大小调整功能
train_iter, test_iter = load_data_fashion_mnist(32, resize=64)
for X, y in train_iter:
print(X.shape, X.dtype, y.shape, y.dtype)
break
[text_labels[int(i)] for i in labels]
# 遍历labels列表中的每个元素i。
# 对于每个i,使用int(i)将其转换为整数。
# 使用转换后的整数索引去访问text_labels列表中的对应元素。
# 将所有访问到的元素收集到一个新的列表中。
import matplotlib.pyplot as plt
# 假设num_rows, num_cols, 和 figsize 已经定义好了
fig, axes = plt.subplots(num_rows, num_cols, figsize=figsize)
# 现在你可以使用fig来访问整个图表对象(例如,设置标题),
# 使用axes来访问和操作子图(Axes)对象
# 如果num_rows和num_cols都大于1,axes将是一个二维数组
# 你可以通过axes[i, j]来访问第i行第j列的子图
# 示例:遍历所有子图并设置标题
for i in range(num_rows):
for j in range(num_cols):
axes[i, j].set_title(f'Axes {i}, {j}')
# 显示图表
plt.show()
for i, (ax, img) in enumerate(zip(axes, imgs))
zip(axes, imgs):
# zip函数是Python中的一个内置函数,用于将多个可迭代对象(如列表、元组等)作为参数,并将它们“打包”成一个迭代器,该迭代器生成一个元组序列,其中每个元组包含来自每个参数序列的对应位置的元素。
# 在这个例子中,axes和imgs是两个可迭代对象,分别包含了多个子图(Axes)对象和图像(或图像数据)对象。zip(axes, imgs)会生成一个迭代器,该迭代器按顺序产出(ax, img)对,其中ax是axes列表中的一个子图对象,img是imgs列表中与ax相对应位置的一个图像对象。
enumerate(...):
# enumerate函数是另一个内置函数,用于将一个可迭代对象(如列表、元组、字符串等)组合为一个索引序列,同时列出数据和数据下标,一般用在for循环当中。
# 在这里,enumerate(zip(axes, imgs))会遍历zip(axes, imgs)生成的迭代器,并为每个(ax, img)对提供一个索引i。这样,在for循环的每次迭代中,你都会得到一个元组(i, (ax, img)),其中i是当前迭代的索引(从0开始),(ax, img)是当前位置上的子图对象和图像对象。
for i, (ax, img) in enumerate(zip(axes, imgs))::
# 这行代码将上述两个概念结合在一起,通过for循环遍历zip(axes, imgs)生成的迭代器,并在每次迭代中解包得到索引i和子图-图像对(ax, img)。
# 这使得你可以在循环体内使用i来访问当前迭代的索引(例如,用于日志记录或条件判断),同时使用ax和img来分别操作当前的子图对象和图像对象(例如,在子图ax上显示图像img)。
# 这种用法在处理多个子图和相应的图像数据时非常有用,因为它提供了一种简洁而强大的方式来遍历这些元素对,并对它们执行操作。
continue在循环语句中的使用:
numbers = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
for number in numbers:
if number % 2 == 0: # 检查数字是否为偶数
continue # 如果是偶数,则跳过当前迭代
print(number) # 打印奇数
for i in range(1, 4): # 外层循环
for j in range(1, 10): # 内层循环
if j % i == 0: # 如果j能被i整除
continue # 这里只会影响内层循环,跳到内层循环的下一个迭代
print(f"{i}x{j}={i*j}", end=' ')
print() # 打印换行符,以便区分不同的i值