在 pytorch 中加载和使用图像分类数据集 Fashion-MNIST

  • 参考:《动手学深度学习》(Pytorch)版 3.5 节
  • 注:本文是 jupyter notebook 文档转换而来,部分代码可能无法直接复制运行!

  • 图像分类数据集中最常用的是手写数字识别数据集MNIST,但大部分模型在MNIST上的分类精度都超过了95%,为了更直观地观察算法之间的差异,本文介绍一个图像内容更加复杂的数据集 Fashion-MNIST,这个数据集难度比 MNIST 高,但是尺寸并不大,只有几十M,没有GPU的电脑也能吃得消
  • 该数据集可以利用 torchvision 包来下载和处理,该包包含以下几个核心模块
    1. torchvision.datasets: 提供加载数据的函数及常用数据集接口;
    2. torchvision.models: 包含常用的模型结构(含预训练模型),如 AlexNet、VGG、ResNet 等;
    3. torchvision.transforms: 提供常用的图片变换方法,例如裁剪、旋转等;
    4. torchvision.utils: 提供其他的一些有用的方法
  • 开始介绍前,先导入包
    import torch
    import torchvision
    import torchvision.transforms as transforms
    import matplotlib.pyplot as plt
    import time
    import numpy as np
    from IPython import display
    

1. 获取数据集

  • 通过 torchvision.datasets.FashionMNIST 方法获取数据集

    mnist_train = torchvision.datasets.FashionMNIST(root='./Datasets/FashionMNIST', train=True, transform=transforms.ToTensor())
    mnist_test = torchvision.datasets.FashionMNIST(root='./Datasets/FashionMNIST', train=False, transform=transforms.ToTensor())
    

    参数说明

    1. root 参数指定数据集保存路径

    2. train 参数指定获取训练集还是测试集

    3. download 参数若设置为 True,则在发现 root 路径下没有数据集时自动从网上下载,若已有数据集则不动作

    4. transform = transforms.ToTensor() 使所有数据转换为 Tensor,如果不转换则返回的是 PIL 图片

      transforms.ToTensor() 将 “尺寸为 H × W × C H \times W \times C H×W×C 且数据位于 [ 0 , 255 ] [0, 255] [0,255] 的PIL图片” 或者 “数据类型为 np.uint8 的NumPy数组” 转换为 “尺寸为 C × H × W C \times H \times W C×H×W 且数据类型为 torch.float32 且位于 [0.0, 1.0] 的Tensor”

      注意 transforms.ToTensor() 在内的一些关于图片的函数默认输入为 uint8 类型,如果不是则可能得到不想要的结果,所以如果用 [ 0 , 255 ] [0,255] [0,255] 的像素值表示图片数据,则一律将其类型设置为 uint8,以免不必要的bug

  • 这里加载的 mnist_trainmnist_test 都是 torch.utils.data.Dataset 的子类,一些常用方法如下

    print(type(mnist_train))
    print(len(mnist_train), len(mnist_test)) # 用 len() 获取该数据集的大小
    
    feature, label = mnist_train[0]          # 通过下标来访问任意样本
    print(feature.shape, label)              # [Channel , Height , Width] label,注意由于数据集中都是灰度图,通道数为 1
    
    '''
    torchvision.datasets.mnist.FashionMNIST
    60000 10000
    torch.Size([1, 28, 28]) 9
    '''
    
  • Fashion-MNIST中一共包括了10个类别,分别为

    1. t-shirt(T恤)
    2. trouser(裤子)
    3. pullover(套衫)
    4. dress(连衣裙)
    5. coat(外套)
    6. sandal(凉鞋)
    7. shirt(衬衫)
    8. sneaker(运动鞋)
    9. bag(包)
    10. ankle boot(短靴)

    使用以下函数将数值标签列表转成相应的文本标签列表

    def get_fashion_mnist_labels(labels):
        text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat',
                       'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
        return [text_labels[int(i)] for i in labels]
    
  • 使用以下函数在一行里绘制多个图像和对应的标签

    def show_fashion_mnist(images, labels):
        display.set_matplotlib_formats('svg')  # Use svg format to display plot in jupyter
        
        _, figs = plt.subplots(1, len(images), figsize=(12, 12))
        for f, img, lbl in zip(figs, images, labels):
            f.imshow(img.view((28, 28)).numpy())
            f.set_title(lbl)
            f.axes.get_xaxis().set_visible(False)
            f.axes.get_yaxis().set_visible(False)
        plt.show()
    
  • 随机显示 10 个样本

    X, y = [], []
    for i in np.random.randint(0,60000,size = 10).tolist():
        X.append(mnist_train[i][0])
        y.append(mnist_train[i][1])
    show_fashion_mnist(X, get_fashion_mnist_labels(y))
    

    这里我遇到一个报错,请参考 ‘OMP: Hint This means that multiple copies of the OpenMP runtime have been linked into the program’,我删除了虚拟环境中的 libiomp5md.dll 解决此问题

在这里插入图片描述

2. 读取小批量

  • 在实践中,数据读取经常是训练的性能瓶颈,torch.utils 模块提供的 DataLoader 方法允许我们方便地使用多进程来加速数据读取

  • mnist_traintorch.utils.data.Dataset 的子类,所以我们可以将其传入 torch.utils.data.DataLoader 来创建一个读取小批量数据样本的DataLoader 实例,在创建时

    1. 通过参数 num_workers 来指定读取数据的进程数量
    2. 通过 shuffle 参数指定读取时是否打乱
    batch_size = 256
    if sys.platform.startswith('win'): # 判断操作系统为 windows
        num_workers = 4 # 使用 4 个进程同时读取
    else:
        num_workers = 0 # 0表示不用额外的进程来加速读取数据
    
    train_iter = torch.utils.data.DataLoader(mnist_train, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    test_iter = torch.utils.data.DataLoader(mnist_test, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    
  • 查看读取一遍数据的耗时

    start = time.time()
    for X, y in train_iter:
        continue
    print('%.2f sec' % (time.time() - start))
    

    经测试,我的笔记本电脑在不使用多进程加速时耗时 5.88s,使用后减少到 3.18s

  • 2
    点赞
  • 12
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
Fashion-MNIST数据集是一个包含10个类别的图像数据集。这些类别分别是:t-shirt(T恤),trouser(牛仔裤),pullover(套衫),dress(裙子),coat(外套),sandal(凉鞋),shirt(衬衫),sneaker(运动鞋),bag(包),ankle boot(短靴)。 Fashion-MNIST数据集MNIST手写数据集不同,它提供了更加多样化的图像样本,为深度学习模型的训练和评估提供了更具挑战性的任务。可以通过torch.utils.data.DataLoader来读取Fashion-MNIST数据集的小批量数据样本,该数据集也是torch.utils.data.Dataset的子类,因此可以直接传入DataLoader来创建一个数据加载器实例。 如果你想了解更多关于Fashion-MNIST数据集的内容,你可以参考相关的文档或教程,并且可以使用批量显示图像的方式来直观地了解数据集的内容。<span class="em">1</span><span class="em">2</span><span class="em">3</span> #### 引用[.reference_title] - *1* *2* [【深度学习系列】——Fashion-MNIST数据集简介](https://blog.csdn.net/weixin_45666566/article/details/107812603)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v93^chatsearchT3_2"}}] [.reference_item style="max-width: 50%"] - *3* [PyTorch深度学习(三):Fashion-MNIST 数据集介绍](https://blog.csdn.net/weixin_48261286/article/details/121195427)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v93^chatsearchT3_2"}}] [.reference_item style="max-width: 50%"] [ .reference_list ]
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

云端FFF

所有博文免费阅读,求打赏鼓励~

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值