【李沐3】3.5、图像分类数据集

# %matplotlib inline
# 上述代码是一个注释,用于在Jupyter Notebook等环境中显示Matplotlib绘图的结果在单元格内部显示,而不是弹出新的窗口。

import torch
import torchvision
from torch.utils import data
from torchvision import transforms
from d2l import torch as d2l

# 导入必要的库和模块
# - torch:PyTorch库,用于构建和训练神经网络
# - torchvision:PyTorch中用于处理图像数据的库
# - torch.utils.data:PyTorch中用于处理数据加载的模块
# - torchvision.transforms:用于定义和应用数据转换的模块
# - d2l.torch:Dive into Deep Learning(《动手深度学习》)书中提供的PyTorch实用函数和工具

d2l.use_svg_display()
# 设置绘图的显示格式为SVG格式,这可以使绘图在Jupyter Notebook中以矢量图形的形式显示,更清晰和美观。

1、读取数据集

# 通过ToTensor实例将图像数据从PIL类型变换成32位浮点数格式,
# 并除以255使得所有像素的数值均在0〜1之间
trans = transforms.ToTensor()

# 创建FashionMNIST数据集的训练集实例
mnist_train = torchvision.datasets.FashionMNIST(
    root="../data",  # 数据集存放的根目录
    train=True,      # 表示加载训练集
    transform=trans, # 数据变换,将图像数据转换为Tensor格式并归一化
    download=True    # 是否下载数据集(如果尚未下载的话)
)
# 创建FashionMNIST数据集的测试集实例
mnist_test = torchvision.datasets.FashionMNIST(
    root="../data",  # 数据集存放的根目录
    train=False,     # 表示加载测试集
    transform=trans, # 数据变换,将图像数据转换为Tensor格式并归一化
    download=True    # 是否下载数据集(如果尚未下载的话)
)

数据集介绍在这里插入图片描述
那么如何查看数据集中图片大小和通道数呢?以及训练和验证数据多少呢?
在这里插入图片描述

下面代码是将数字标签转换为文本标签

def get_fashion_mnist_labels(labels): #@save
    """
    返回Fashion-MNIST数据集的文本标签
    参数:
        labels: 包含数值标签的列表或数组
    返回:
        包含对应文本标签的列表
    """
    text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat',
                   'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
    return [text_labels[int(i)] for i in labels]

上面是啥意思呢?就是比如1表示苹果,这里以前标记的是1,现在转换为苹果

下面是可视化

def show_images(imgs, num_rows, num_cols, titles=None, scale=1.5): #@save
    """
    绘制图像列表
    参数:
        imgs: 包含图像的列表
        num_rows: 图像展示的行数
        num_cols: 图像展示的列数
        titles: 可选参数,图像标题的列表
        scale: 可选参数,控制图像的缩放比例
    返回:
        无返回值,显示绘制的图像
    """
    # 计算绘图区域的尺寸
    figsize = (num_cols * scale, num_rows * scale)
    # 创建一个具有指定尺寸的子图
    _, axes = d2l.plt.subplots(num_rows, num_cols, figsize=figsize)
    # 将子图数组展平,以便逐个访问每个子图
    axes = axes.flatten()
    
    # 遍历图像列表并在每个子图中绘制图像
    for i, (ax, img) in enumerate(zip(axes, imgs)):
        if torch.is_tensor(img):
            # 如果图像是PyTorch的张量,将其转换为NumPy数组并在子图上显示
            ax.imshow(img.numpy())
        else:
            # 如果图像是PIL图像,直接在子图上显示
            ax.imshow(img)
        # 隐藏子图的x轴和y轴
        ax.axes.get_xaxis().set_visible(False)
        ax.axes.get_yaxis().set_visible(False)
        if titles:
            # 如果提供了标题列表,设置当前子图的标题
            ax.set_title(titles[i])
    
    # 返回绘制的子图数组
    return axes

2、读取小批量
问:
(1)一个进程通常占用一个核心吗
是的

batch_size = 256
def get_dataloader_workers(): #@save
    """使用4个进程来读取数据"""
    return 4

# 创建训练数据迭代器
train_iter = data.DataLoader(
    mnist_train,                       # 使用的数据集实例
    batch_size,                        # 每个批次的样本数量
    shuffle=True,                      # 是否在每个epoch前打乱数据顺序
    num_workers=get_dataloader_workers() # 用于加载数据的进程数量
)

(2)上面的data在哪里定义的?
看文章开头定义的

3、整合所有组件
就是合成上边所有的代码

def load_data_fashion_mnist(batch_size, resize=None): #@save
    """
    下载Fashion-MNIST数据集,然后将其加载到内存中
    参数:
        batch_size: 批次大小,用于小批量训练
        resize: 可选参数,指定图像调整的大小
    返回:
        包含训练数据迭代器和测试数据集的元组
    """
    # 创建数据变换列表,将图像转换为Tensor格式
    trans = [transforms.ToTensor()]
    
    # 如果提供了resize参数,将图像调整大小添加到变换列表
    if resize:
        trans.insert(0, transforms.Resize(resize))
        
    # 将变换列表组合成一个组合变换
    trans = transforms.Compose(trans)
    
    # 创建FashionMNIST数据集的训练集实例
    mnist_train = torchvision.datasets.FashionMNIST(
        root="../data",              # 数据集存放的根目录
        train=True,                  # 表示加载训练集
        transform=trans,             # 数据变换,包括调整大小和转换为Tensor
        download=True                # 是否下载数据集(如果尚未下载的话)
    )
    
    # 创建FashionMNIST数据集的测试集实例
    mnist_test = torchvision.datasets.FashionMNIST(
        root="../data",              # 数据集存放的根目录
        train=False,                 # 表示加载测试集
        transform=trans,             # 数据变换,包括调整大小和转换为Tensor
        download=True                # 是否下载数据集(如果尚未下载的话)
    )
    
    # 创建训练数据迭代器,并指定批次大小、是否打乱顺序和数据加载进程数量
    train_data = data.DataLoader(
        mnist_train,                 # 使用的训练数据集实例
        batch_size,                  # 每个批次的样本数量
        shuffle=True,                # 是否在每个epoch前打乱数据顺序
        num_workers=get_dataloader_workers()  # 数据加载进程数量
    )
    
    # 创建测试数据迭代器,并指定批次大小、不打乱顺序和数据加载进程数量
    test_data = data.DataLoader(
        mnist_test,                  # 使用的测试数据集实例
        batch_size,                  # 每个批次的样本数量
        shuffle=False,               # 不打乱数据顺序
        num_workers=get_dataloader_workers()  # 数据加载进程数量
    )
    
    # 返回训练数据迭代器和测试数据迭代器的元组
    return train_data, test_data

下面这张图片是对上面的的数据进行调用和,查看数据大小是否改变
在这里插入图片描述

  • 1
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值