图像分类数据集(PyTorch)

       本代码基于《动手学深度学习》Pytorch版,第三章线性回归网络,第五节图像分类数据集。对代码进行修改,增加注释,供学习使用。

导入相关库

import matplotlib_inline
import torchvision
import matplotlib.pyplot as plt
import torch
import time
import numpy as np
plt.rcParams['font.sans-serif'] = ['SimHei']
# 显示中文,在Windows系统上,选择SimHei(黑体)或其他中文字体,将其设置为Matplotlib的默认字体
# 设置JupyterNotebook中Matplotlib默认图像显示格式为SVG,让图表更易嵌入网页,并可在各种分辨率和设备上呈现出高质量效果
# SVG(ScalableVectorGraphics,可缩放矢量图形)基于XML的矢量图形格式,可缩放和调整大小而不失真
# SVG图像具有更高的质量和可扩展性,在需更清晰的图像或需要缩放图像时使用
# 在JupyterNotebook中,默认Matplotlib生成的图像会以PNG格式显示
def use_svg_display():
    matplotlib_inline.backend_inline.set_matplotlib_formats('svg')
    # backend_inline是JupyterNotebook设置Matplotlib后端的模块,使生成的Matplotlib图像直接嵌入笔记本中,而不是单独窗口中显示
    
    # set_matplotlib_formats()设置Matplotlib图像显示格式,允许为Matplotlib生成的图像指定所需的输出格式
    # 可接受其他参数,quality(设置JPEG输出格式图像质量),dpi(设置图像每英寸点数)等
# 设置图像显示格式
use_svg_display()

加载并处理Fashion-MNIST数据集

# 加载并处理Fashion-MNIST数据集
trans = torchvision.transforms.ToTensor()
# ToTensor()数据预处理,将输入数据(如PIL图像或NumPy数组)转换为张量
# 将图像像素值从0-255范围转换为0-1范围(浮点数)
# 将图像的通道顺序从HWC(高度、宽度、通道)转换为CHW(通道、高度、宽度),这是PyTorch常用的格式
# 通常与其他数据预处理操作一起使用,以便将图像数据准备为神经网络训练所需的格式
train = torchvision.datasets.FashionMNIST(root = 'C:\\Users\\kongbai\\study\\数据集\\fashionMNIST', train = True, transform = trans, download = True)
# FashionMNIST()加载和处理Fashion-MNIST数据集
# root:数据集的存储路径,默认为./data
# train:是否加载训练数据,默认为True
# transform:用于数据预处理的变换,默认为None
# target_transform:用于目标(标签)预处理的变换,默认为 None
# download:是否下载数据集,如果数据集不存在,则下载,默认为True
# 默认下载数据集,如果数据集已存在,则不会再次下载,如果需要强制下载,可将download参数设置为True
test = torchvision.datasets.FashionMNIST(root = 'C:\\Users\\kongbai\\study\\数据集\\fashionMNIST', train = False, transform = trans, download = True)
print(len(train))
print(len(test))
print(train[0][0].shape)
print(test[0][0].shape)

运行结果

60000
10000
torch.Size([1, 28, 28])
torch.Size([1, 28, 28])

返回Fashion-MNIST数据集的文本标签

def get_label(labels):
    label = ['T恤', '裤子', '套衫', '连衣裙', '外套', '凉鞋', '衬衫', '运动鞋', '包', '短靴']
    return [label[int(i)] for i in labels]
    # [expression for item in iterable if condition]
    # 列表推导式是Python的语法,根据已有的列表或其他可迭代对象快速生成新的列表,让代码更加简洁易读,提高代码执行效率
    # 可用于生成新列表,可嵌套使用,以实现更复杂的逻辑
    # 过度使用可能会导致代码难以阅读和维护,在实际编程中要根据具体情况权衡使用
    # expression对item的操作表达式,生成新列表的元素
    # item可迭代对象中的每个元素
    # iterable可迭代对象
    # condition可选的条件表达式,过滤可迭代对象中的元素

可视化Fashion-MNIST数据集图像

def show(imgs, rows, cols, titles = None, scale = 1.5):
    _, axes = plt.subplots(rows, cols, figsize = (cols * scale, rows * scale))
    axes = axes.flatten()
    # flatten()将多维数组展平为一维数组,返回新的数组,原始数组不被修改
    # 如果想要在原地修改数组,可使用ravel(),与flatten()功能类似,返回原始数组的一个视图(view)而不是副本
    # 如果使用的是Python列表,需要使用其他方法来展平列表,可使用列表推导式
    for i, (ax, img) in enumerate(zip(axes, imgs)):
    # enumerate(iterable, start = 0)遍历可迭代对象,同时返回每个元素的索引和值
    # 返回枚举对象,可使用next()或将其转换为列表、元组等数据结构
    # iterable可迭代对象
    # start可选参数,指定索引的起始值,默认为0
        
    # zip(*iterables)将多个可迭代对象组合成一个新的可迭代对象
    # 按顺序从每个可迭代对象中取出相同位置的元素,然后将这些元素组合成新的元组,最后所有这些元组会被组合成新的可迭代对象
    # iterables一个或多个可迭代对象
    # zip()会按最短的可迭代对象进行组合,如果输入的可迭代对象长度不同,较长的可迭代对象中多余的元素将被忽略
        if torch.is_tensor(img):
        # is_tensor()检查对象是否为张量
            ax.imshow(img.numpy())
            # imshow(image, cmap = None)显示图像,可接受多种类型的输入,并将其显示在当前的Axes对象上
            # 默认使用matplotlib的颜色映射显示图像,灰度图像通常没有问题,彩色图像,可能需要设置cmap参数为None,以便显示颜色
            # 还可接受其他参数,控制图像的显示效果
            
            # numpy()将张量转换为NumPy数组,要求张量位于CPU上,如果张量位于GPU上,需先将其移动到CPU上,然后再使用 
            # 返回的NumPy数组与原始PyTorch张量共享内存,如果修改NumPy数组,原始张量的值也会发生变化,反之亦然
            # 如果希望创建独立的NumPy数组副本,可使用numpy_array.copy()
        else:
            ax.imshow(img)
        ax.axes.get_xaxis().set_visible(False)
        # axes属性,访问当前坐标轴对象
        
        # get_xaxis()获取当前坐标轴的x轴对象
        
        # set_visible()设置对象可见性,可应用于各种matplotlib对象,如轴(Axes)、刻度(Ticks)、图例(Legend)等
        ax.axes.get_yaxis().set_visible(False)
        # get_yaxis()获取当前坐标轴的y轴对象
        if titles:
            ax.set_title(titles[i])
    return axes
x, y = next(iter(torch.utils.data.DataLoader(train, batch_size = 10)))
# next()获取迭代器的下一个元素
# 当迭代器中没有更多元素时,将引发StopIteration异常,可使用try-except处理
# 接受可选的第二个参数,该参数用作默认值,当迭代器中没有更多元素时,将返回此默认值,而不引发异常

# iter()从可迭代对象创建一个迭代器,允许遍历可迭代对象的元素,一次一个
# 迭代器只能遍历一次,已经遍历了所有元素后,再次尝试使用next()函数将引发StopIteration异常

# DataLoader()批量加载数据集,将数据集分成多个批次,并支持多线程/多进程数据加载,从而加速数据加载过程
# dataset(Dataset)要加载的数据集对象,此对象需实现__len__()和__getitem__()方法
# batch_size(int)可选,每批次的大小,默认为1
# shuffle(bool)可选,是否在每个epoch开始时打乱数据,默认为False
# sampler(Sampler)可选,自定义采样器,决定从数据集中提取哪些样本,如果提供sampler,则shuffle参数将被忽略
# num_workers(int)可选,数据加载的子进程数,默认为0,表示不使用多进程
# collate_fn(callable)可选,合并批次数据的函数,默认使用torch.utils.data._utils.collate.default_collate
# pin_memory(bool)0可选,如果为True,则将数据加载到固定内存中,以便更快地传输到GPU,默认为False
# drop_last(bool)可选,如果为True,则丢弃最后一个不完整的批次,默认为False
# timeout(numeric)可选,等待从工人进程接收数据的超时值(以秒为单位),默认为0
# worker_init_fn(callable)可选,在每个工人进程中调用,初始化工人进程,默认不做任何操作
show(x.reshape(10, 28, 28), 2, 5, titles = get_label(y));

运行结果

计时器

class Timer:
    # 记录多次运行时间
    def __init__(self):
        self.times = []
        self.start()
    
    # 启动计时器
    def start(self):
        self.tik = time.time()
    
    # 停止计时器并将时间记录在列表中
    def stop(self):
        self.times.append(time.time() - self.tik)
        return self.times[-1]

    # 返回平均时间
    def avg(self):
        return sum(self.times) / len(self.times)
    
    # 返回时间总和
    def sum(self):
        return sum(self.times)
    
    # 返回累计时间
    def cumsum(self):
        return np.array(self.times).cumsum().tolist()

构建数据迭代器读取小批量数据并计算读取所需时间

batch = 256
train_iter = torch.utils.data.DataLoader(train, batch, shuffle = True, num_workers = 4)
timer = Timer()
for x, y in train_iter:
    continue
print(f'{timer.stop():.2f} sec')

运行结果

3.70 sec

整合组件

def dataset(batch, resize = None):
    trans = [torchvision.transforms.ToTensor()]
    if resize:
        trans.insert(0, torchvision.transforms.Resize(resize))
        # insert()是Python列表的方法,在列表的指定位置插入一个元素,接受两个参数:插入元素的索引,要插入的元素
        # 如果指定的索引超出列表的范围,将引发IndexError异常
        
        # Resize()调整图像的大小,接受一个表示新尺寸的元组作为参数,表示新图像的宽度和高度
        # 返回新的调整大小的图像对象,而不修改原始图像对象
        # 还支持额外的参数,例如interpolation,指定插值方法。以下是一些常用的插值方法:
        # PIL.Image.NEAREST最近邻插值
        # PIL.Image.BILINEAR双线性插值(默认)
        # PIL.Image.BICUBIC双三次插值
        # PIL.Image.LANCZOS:Lanczos插值
    trans = torchvision.transforms.Compose(trans)
    # compose()将多个函数组合成一个新的函数,接受一系列函数作为参数,并返回一个新的函数,该函数按顺序应用这些函数
    # compose()是从右到左应用函数的,即最右边的函数首先被应用,如需从左到右应用函数,可使用reduce()和operator.mul
    train = torchvision.datasets.FashionMNIST(root = 'C:\\Users\\kongbai\\study\\数据集\\fashionMNIST', train = True, transform = trans, download = True)
    test = torchvision.datasets.FashionMNIST(root = 'C:\\Users\\kongbai\\study\\数据集\\fashionMNIST', train = False, transform = trans, download = True)
    return torch.utils.data.DataLoader(train, batch, shuffle = True), torch.utils.data.DataLoader(test, batch, shuffle = False)
dataset_train, dataset_test = dataset(32, resize = 64)
for x, y in dataset_train:
    print(x.shape)
    print(x.dtype)
    print(y.shape)
    print(y.dtype)
    break

运行结果

torch.Size([32, 1, 64, 64])
torch.float32
torch.Size([32])
torch.int64
  • 9
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值