本代码基于《动手学深度学习》Pytorch版,第三章线性回归网络,第五节图像分类数据集。对代码进行修改,增加注释,供学习使用。
导入相关库
import matplotlib_inline
import torchvision
import matplotlib.pyplot as plt
import torch
import time
import numpy as np
plt.rcParams['font.sans-serif'] = ['SimHei']
# 显示中文,在Windows系统上,选择SimHei(黑体)或其他中文字体,将其设置为Matplotlib的默认字体
# 设置JupyterNotebook中Matplotlib默认图像显示格式为SVG,让图表更易嵌入网页,并可在各种分辨率和设备上呈现出高质量效果
# SVG(ScalableVectorGraphics,可缩放矢量图形)基于XML的矢量图形格式,可缩放和调整大小而不失真
# SVG图像具有更高的质量和可扩展性,在需更清晰的图像或需要缩放图像时使用
# 在JupyterNotebook中,默认Matplotlib生成的图像会以PNG格式显示
def use_svg_display():
matplotlib_inline.backend_inline.set_matplotlib_formats('svg')
# backend_inline是JupyterNotebook设置Matplotlib后端的模块,使生成的Matplotlib图像直接嵌入笔记本中,而不是单独窗口中显示
# set_matplotlib_formats()设置Matplotlib图像显示格式,允许为Matplotlib生成的图像指定所需的输出格式
# 可接受其他参数,quality(设置JPEG输出格式图像质量),dpi(设置图像每英寸点数)等
# 设置图像显示格式
use_svg_display()
加载并处理Fashion-MNIST数据集
# 加载并处理Fashion-MNIST数据集
trans = torchvision.transforms.ToTensor()
# ToTensor()数据预处理,将输入数据(如PIL图像或NumPy数组)转换为张量
# 将图像像素值从0-255范围转换为0-1范围(浮点数)
# 将图像的通道顺序从HWC(高度、宽度、通道)转换为CHW(通道、高度、宽度),这是PyTorch常用的格式
# 通常与其他数据预处理操作一起使用,以便将图像数据准备为神经网络训练所需的格式
train = torchvision.datasets.FashionMNIST(root = 'C:\\Users\\kongbai\\study\\数据集\\fashionMNIST', train = True, transform = trans, download = True)
# FashionMNIST()加载和处理Fashion-MNIST数据集
# root:数据集的存储路径,默认为./data
# train:是否加载训练数据,默认为True
# transform:用于数据预处理的变换,默认为None
# target_transform:用于目标(标签)预处理的变换,默认为 None
# download:是否下载数据集,如果数据集不存在,则下载,默认为True
# 默认下载数据集,如果数据集已存在,则不会再次下载,如果需要强制下载,可将download参数设置为True
test = torchvision.datasets.FashionMNIST(root = 'C:\\Users\\kongbai\\study\\数据集\\fashionMNIST', train = False, transform = trans, download = True)
print(len(train))
print(len(test))
print(train[0][0].shape)
print(test[0][0].shape)
运行结果
60000
10000
torch.Size([1, 28, 28])
torch.Size([1, 28, 28])
返回Fashion-MNIST数据集的文本标签
def get_label(labels):
label = ['T恤', '裤子', '套衫', '连衣裙', '外套', '凉鞋', '衬衫', '运动鞋', '包', '短靴']
return [label[int(i)] for i in labels]
# [expression for item in iterable if condition]
# 列表推导式是Python的语法,根据已有的列表或其他可迭代对象快速生成新的列表,让代码更加简洁易读,提高代码执行效率
# 可用于生成新列表,可嵌套使用,以实现更复杂的逻辑
# 过度使用可能会导致代码难以阅读和维护,在实际编程中要根据具体情况权衡使用
# expression对item的操作表达式,生成新列表的元素
# item可迭代对象中的每个元素
# iterable可迭代对象
# condition可选的条件表达式,过滤可迭代对象中的元素
可视化Fashion-MNIST数据集图像
def show(imgs, rows, cols, titles = None, scale = 1.5):
_, axes = plt.subplots(rows, cols, figsize = (cols * scale, rows * scale))
axes = axes.flatten()
# flatten()将多维数组展平为一维数组,返回新的数组,原始数组不被修改
# 如果想要在原地修改数组,可使用ravel(),与flatten()功能类似,返回原始数组的一个视图(view)而不是副本
# 如果使用的是Python列表,需要使用其他方法来展平列表,可使用列表推导式
for i, (ax, img) in enumerate(zip(axes, imgs)):
# enumerate(iterable, start = 0)遍历可迭代对象,同时返回每个元素的索引和值
# 返回枚举对象,可使用next()或将其转换为列表、元组等数据结构
# iterable可迭代对象
# start可选参数,指定索引的起始值,默认为0
# zip(*iterables)将多个可迭代对象组合成一个新的可迭代对象
# 按顺序从每个可迭代对象中取出相同位置的元素,然后将这些元素组合成新的元组,最后所有这些元组会被组合成新的可迭代对象
# iterables一个或多个可迭代对象
# zip()会按最短的可迭代对象进行组合,如果输入的可迭代对象长度不同,较长的可迭代对象中多余的元素将被忽略
if torch.is_tensor(img):
# is_tensor()检查对象是否为张量
ax.imshow(img.numpy())
# imshow(image, cmap = None)显示图像,可接受多种类型的输入,并将其显示在当前的Axes对象上
# 默认使用matplotlib的颜色映射显示图像,灰度图像通常没有问题,彩色图像,可能需要设置cmap参数为None,以便显示颜色
# 还可接受其他参数,控制图像的显示效果
# numpy()将张量转换为NumPy数组,要求张量位于CPU上,如果张量位于GPU上,需先将其移动到CPU上,然后再使用
# 返回的NumPy数组与原始PyTorch张量共享内存,如果修改NumPy数组,原始张量的值也会发生变化,反之亦然
# 如果希望创建独立的NumPy数组副本,可使用numpy_array.copy()
else:
ax.imshow(img)
ax.axes.get_xaxis().set_visible(False)
# axes属性,访问当前坐标轴对象
# get_xaxis()获取当前坐标轴的x轴对象
# set_visible()设置对象可见性,可应用于各种matplotlib对象,如轴(Axes)、刻度(Ticks)、图例(Legend)等
ax.axes.get_yaxis().set_visible(False)
# get_yaxis()获取当前坐标轴的y轴对象
if titles:
ax.set_title(titles[i])
return axes
x, y = next(iter(torch.utils.data.DataLoader(train, batch_size = 10)))
# next()获取迭代器的下一个元素
# 当迭代器中没有更多元素时,将引发StopIteration异常,可使用try-except处理
# 接受可选的第二个参数,该参数用作默认值,当迭代器中没有更多元素时,将返回此默认值,而不引发异常
# iter()从可迭代对象创建一个迭代器,允许遍历可迭代对象的元素,一次一个
# 迭代器只能遍历一次,已经遍历了所有元素后,再次尝试使用next()函数将引发StopIteration异常
# DataLoader()批量加载数据集,将数据集分成多个批次,并支持多线程/多进程数据加载,从而加速数据加载过程
# dataset(Dataset)要加载的数据集对象,此对象需实现__len__()和__getitem__()方法
# batch_size(int)可选,每批次的大小,默认为1
# shuffle(bool)可选,是否在每个epoch开始时打乱数据,默认为False
# sampler(Sampler)可选,自定义采样器,决定从数据集中提取哪些样本,如果提供sampler,则shuffle参数将被忽略
# num_workers(int)可选,数据加载的子进程数,默认为0,表示不使用多进程
# collate_fn(callable)可选,合并批次数据的函数,默认使用torch.utils.data._utils.collate.default_collate
# pin_memory(bool)0可选,如果为True,则将数据加载到固定内存中,以便更快地传输到GPU,默认为False
# drop_last(bool)可选,如果为True,则丢弃最后一个不完整的批次,默认为False
# timeout(numeric)可选,等待从工人进程接收数据的超时值(以秒为单位),默认为0
# worker_init_fn(callable)可选,在每个工人进程中调用,初始化工人进程,默认不做任何操作
show(x.reshape(10, 28, 28), 2, 5, titles = get_label(y));
运行结果
计时器
class Timer:
# 记录多次运行时间
def __init__(self):
self.times = []
self.start()
# 启动计时器
def start(self):
self.tik = time.time()
# 停止计时器并将时间记录在列表中
def stop(self):
self.times.append(time.time() - self.tik)
return self.times[-1]
# 返回平均时间
def avg(self):
return sum(self.times) / len(self.times)
# 返回时间总和
def sum(self):
return sum(self.times)
# 返回累计时间
def cumsum(self):
return np.array(self.times).cumsum().tolist()
构建数据迭代器读取小批量数据并计算读取所需时间
batch = 256
train_iter = torch.utils.data.DataLoader(train, batch, shuffle = True, num_workers = 4)
timer = Timer()
for x, y in train_iter:
continue
print(f'{timer.stop():.2f} sec')
运行结果
3.70 sec
整合组件
def dataset(batch, resize = None):
trans = [torchvision.transforms.ToTensor()]
if resize:
trans.insert(0, torchvision.transforms.Resize(resize))
# insert()是Python列表的方法,在列表的指定位置插入一个元素,接受两个参数:插入元素的索引,要插入的元素
# 如果指定的索引超出列表的范围,将引发IndexError异常
# Resize()调整图像的大小,接受一个表示新尺寸的元组作为参数,表示新图像的宽度和高度
# 返回新的调整大小的图像对象,而不修改原始图像对象
# 还支持额外的参数,例如interpolation,指定插值方法。以下是一些常用的插值方法:
# PIL.Image.NEAREST最近邻插值
# PIL.Image.BILINEAR双线性插值(默认)
# PIL.Image.BICUBIC双三次插值
# PIL.Image.LANCZOS:Lanczos插值
trans = torchvision.transforms.Compose(trans)
# compose()将多个函数组合成一个新的函数,接受一系列函数作为参数,并返回一个新的函数,该函数按顺序应用这些函数
# compose()是从右到左应用函数的,即最右边的函数首先被应用,如需从左到右应用函数,可使用reduce()和operator.mul
train = torchvision.datasets.FashionMNIST(root = 'C:\\Users\\kongbai\\study\\数据集\\fashionMNIST', train = True, transform = trans, download = True)
test = torchvision.datasets.FashionMNIST(root = 'C:\\Users\\kongbai\\study\\数据集\\fashionMNIST', train = False, transform = trans, download = True)
return torch.utils.data.DataLoader(train, batch, shuffle = True), torch.utils.data.DataLoader(test, batch, shuffle = False)
dataset_train, dataset_test = dataset(32, resize = 64)
for x, y in dataset_train:
print(x.shape)
print(x.dtype)
print(y.shape)
print(y.dtype)
break
运行结果
torch.Size([32, 1, 64, 64])
torch.float32
torch.Size([32])
torch.int64