零基础-动手学深度学习-3.5图像分类数据集

 一、关于F-MNIST数据集的一些操作

import matplotlib.pyplot as plt
import torch
import torchvision #pytorch对于计算机视觉模型实现库
from torch.utils import data
from torchvision import transforms #对数据进行操作的步骤
from d2l import torch as d2l

d2l.use_svg_display() #用svg显示图片,清晰度高

#通过框架中的内置函数将Fashion-MNIST数据集下载并读取到内存中
# 通过ToTensor实例将图像数据从PIL类型变换成32位浮点数格式,
# 并除以255使得所有像素的数值均在0~1之间
trans = transforms.ToTensor()#最简单的预处理
mnist_train = torchvision.datasets.FashionMNIST(
    root="../data", train=True, transform=trans, download=True)
#                   训练数据集true,数据类型是pytorch的tensor而不是图片
mnist_test = torchvision.datasets.FashionMNIST(
    root="../data", train=False, transform=trans, download=True)

#print(len(mnist_train), len(mnist_test))看一下训练集和测试集多长
#mnist_train[0][0].shape看看训练集第一个训练集的图片(1是标号)的形状

def get_fashion_mnist_labels(labels):  #@save
    """返回Fashion-MNIST数据集的文本标签"""
    text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat',
                   'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
    return [text_labels[int(i)] for i in labels]


#imgs:图像列表,可以是 PyTorch 张量或 PIL 图像
#num_rows 和 num_cols:决定了网格的行数和列数,即图像展示的布局
#titles(可选参数):图像标题列表,可以为每幅图像提供标题
#scale:缩放比例,控制图像的大小
def show_images(imgs, num_rows, num_cols, titles=None, scale=1.5):  #@save
    """绘制图像列表"""
    figsize = (num_cols * scale, num_rows * scale)
    _, axes = d2l.plt.subplots(num_rows, num_cols, figsize=figsize)
    #axes 是一个包含子图的数组,d2l.plt.subplots 函数会创建一个 num_rows 行,num_cols 列的网格,每个网格位置用于显示一张图像
    #这里的-,表示占位符,因为subplots的输出有一个figure没有使用表明我们并不关心这个返回值
    axes = axes.flatten()
    #将二维的 axes 数组展开成一维数组,以便后面更方便地遍历处理每个子图
    for i, (ax, img) in enumerate(zip(axes, imgs)):
        #使用 enumerate 对 axes 和 imgs 进行同步遍历,每次分别获得一个轴 ax 和一张图像 img
        if torch.is_tensor(img):
            # 图片张量
            ax.imshow(img.numpy())
            #判断 img 是否是 PyTorch 张量(使用 torch.is_tensor 函数),
            #如果是,则需要调用 img.numpy() 将其转化为 NumPy 数组以便 imshow 可以显示它
        else:
            # PIL图片
            ax.imshow(img)
            #如果 img 不是张量,则假定它是一个 PIL 图像,直接用 imshow 显示
        ax.axes.get_xaxis().set_visible(False)
        ax.axes.get_yaxis().set_visible(False)
        #如果 img 不是张量,则假定它是一个 PIL 图像,直接用 imshow 显示
        if titles:
            ax.set_title(titles[i])
            #如果有 titles 提供,则为每个图像设置对应的标题
    return axes

#X, y = next(iter(data.DataLoader(mnist_train, batch_size=18)))看一看训练集前几个样本及其标签
#show_images(X.reshape(18, 28, 28), 2, 9, titles=get_fashion_mnist_labels(y))

#读取小批量
batch_size = 256
def get_dataloader_workers():  #@save
    """使用4个进程来读取数据"""
    return 4

train_iter = data.DataLoader(mnist_train, batch_size, shuffle=True,
                             num_workers=get_dataloader_workers())
#看一下读取训练数据的所需的时间
timer = d2l.Timer()
for X, y in train_iter:
    continue
f'{timer.stop():.2f} sec'

二、定义load_data_fashion_mnist函数

def load_data_fashion_mnist(batch_size, resize=None):  #@save
    """下载Fashion-MNIST数据集,然后将其加载到内存中"""
    trans = [transforms.ToTensor()]
    if resize:
        trans.insert(0, transforms.Resize(resize))
    trans = transforms.Compose(trans)
    mnist_train = torchvision.datasets.FashionMNIST(
        root="../data", train=True, transform=trans, download=True)
    mnist_test = torchvision.datasets.FashionMNIST(
        root="../data", train=False, transform=trans, download=True)
    return (data.DataLoader(mnist_train, batch_size, shuffle=True,
                            num_workers=get_dataloader_workers()),
            data.DataLoader(mnist_test, batch_size, shuffle=False,
                            num_workers=get_dataloader_workers()))


#我们通过指定参数来测试函数的图像大小调整功能
rain_iter, test_iter = load_data_fashion_mnist(32, resize=64)
for X, y in train_iter:
    print(X.shape, X.dtype, y.shape, y.dtype)
    break

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值