动手学深度学习 图像分类数据集(一) Fashion-MNIST的获取与查看

动手学深度学习 图像分类数据集(一) Fashion-MNIST的获取与查看

动手学深度学习 图像分类数据系列:


Fashion-MNIST在书中多次使用,本文的内容是讲解如何获取并查看此数据集


1.下载数据集

使用torchvision.datasets来下载数据集

  • root 用来指定下载后保存的位置(如果已经存在则不会下载)
  • download表示是否要下载
  • train 表示获取训练数据集或测试数据集
  • transform代表对图像的操作, 这里仅仅使用了ToTensor()把图像数据转换为Tensor类型
    其格式为( C ∗ H ∗ W C*H*W CHW)

更多transform的操作可以点击这篇文章来查看

书本原话: 
注意:由于像素值为0到255的整数,所以刚好是uint8所能表示的范围,包括
transforms.ToTensor() 在内的一些关于图片的函数就默认输入的是uint8型,若不是,可能不会报错
但可能得不到想要的结果。所以,如果用像素值(0-255整数)表示图片数据,那么一律将其类型设置成
uint8,避免不必要的bug。
import torchvision
import torchvision.transforms as transforms
mnist_train = torchvision.datasets.FashionMNIST(root=r'D:\Source\Datasets\FashionMNIST', train=True, download=True,
                                                transform=transforms.ToTensor())
mnist_test = torchvision.datasets.FashionMNIST(root=r'D:\Source\Datasets\FashionMNIST', train=False, download=True,
                                               transform=transforms.ToTensor())

查看一下读取的结果
在这里插入图片描述

2.查看数据集结构

对训练集切片查看一下数据类型和标签类型
在这里插入图片描述
这里的标签已经转换为数值型数据来存储
所以我们可以编写一个函数将其转换为 图像数据集原本对应的标签

def get_fashion_mnist_labels(labels):
    text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat',
                   'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
    return [text_labels[int(i)] for i in labels]

在这里插入图片描述

3.查看图片与标签

先提取出其中的一张图片与标签来查看

img, label = mnist_train[0]
title = get_fashion_mnist_labels([label])[0] # 获取标签
plt.imshow(img.view((28,28)).numpy())	# 数据格式转换
plt.title(title)	# 设置标题
plt.savefig('test.jpg')	# 存储图片

在这里插入图片描述
查看多个图片和标签(以前十张为例)

import matplotlib.pyplot as plt
def show_fashion_mnist(images, labels):
    # 这里的_表示我们忽略(不使用)的变量
    _, figs = plt.subplots(1, len(images), figsize=(12, 12))
    for f, img, lbl in zip(figs, images, labels):
        f.imshow(img.view((28, 28)).numpy())
        f.set_title(lbl)
        f.axes.get_xaxis().set_visible(False)
        f.axes.get_yaxis().set_visible(False)
X, y = [], []
for i in range(10):
    X.append(mnist_train[i][0])
    y.append(mnist_train[i][1])
show_fashion_mnist(X, get_fashion_mnist_labels(y))
plt.show()

在这里插入图片描述

4.按小批次读取数据集

使用DataLoader 它可以允许多线程来加速数据读取

具体的可以看下面链接中的文章,有对DataLoaderDataset的详细介绍
Pytorch 快速详解如何构建自己的Dataset完成数据预处理(附详细过程)

from torch.utils.data import DataLoader
import sys
batch_size = 256
if sys.platform.startswith('win'):
    # 0表示不用额外的进程来加速读取数据
    num_workers = 0
else:
    num_workers = 4
train_iter = DataLoader(mnist_train,
                        batch_size=batch_size,
                        shuffle=True,
                        num_workers=num_workers)
test_iter = DataLoader(mnist_test,
                       batch_size=batch_size,
                       shuffle=False,
                       num_workers=num_workers)

DataLoader是个可遍历的对象

start = time()
for X, y in train_iter:
	continue
print('%.2f sec' % (time() - start))

可以通过上述代码来查看读取一遍训练集需要的时间

引用资料来源

本文内容来自吴振宇博士的Github项目
对中文版《动手学深度学习》中的代码进行整理,并用Pytorch实现
【深度学习】李沐《动手学深度学习》的PyTorch实现已完成

  • 2
    点赞
  • 12
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Joker-Tong

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值