第一章 动手学深度学习( PyTorch)笔记 ------ 线性神经网络(三)


本文仅作为个人学习笔记用,如有错误请指正,欢迎大家讨论学习。本博客内容来自动手学深度学习


一、图像分类数据集

导包

import torch
import torchvision
from torch.utils import data
from torchvision import transforms
from d2l import torch as d2l
from IPython import display

数据集下载

创建子目录下载训练集以及测试集数据

trans = transforms.ToTensor()
mnist_train = torchvision.datasets.FashionMNIST(
    root="../data", train=True, transform=trans, download=True)
mnist_test = torchvision.datasets.FashionMNIST(
    root="../data", train=False, transform=trans, download=True)

标签管理

接受label参数后,使用列表推导式来生成并返回一个新的列表。对于输入列表 labels中的每个元素 i,它首先将 i 转换为整数(在大多数情况下,labels 中的元素可能已经是整数类型,这一步可能是为了确保兼容性或处理特殊情况),然后使用这个整数作为索引从 text_labels 列表中取出对应的文本标签。最终,这个表达式生成一个包含所有对应文本标签的新列表,并将其返回。

def get_fashion_mnist_labels(labels):  #@save
    """返回Fashion-MNIST数据集的文本标签"""
    text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat',
                   'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
    return [text_labels[int(i)] for i in labels]

样本可视化

figsize由接收的 num_rows, num_cols决定,scale为设置好的参数;
创建一个子图网格axes,行列由num_rows, num_cols,其大小由figsize确定;
axes.flatten()将二维的子图数组转换为一维,以便后续遍历
接着,函数遍历每个子图和对应的图像,根据图像的类型,使用ax.imshow()方法显示图像。
对于每个子图,还关闭了x轴和y轴的显示,并可选地设置了标题。
最后,函数返回包含所有子图的数组。

def show_images(imgs, num_rows, num_cols, titles=None, scale=1.5):  #@save
    """绘制图像列表"""
    figsize = (num_cols * scale, num_rows * scale)
    _, axes = d2l.plt.subplots(num_rows, num_cols, figsize=figsize)
    axes = axes.flatten()
    for i, (ax, img) in enumerate(zip(axes, imgs)):
        if torch.is_tensor(img):
            # 图片张量
            ax.imshow(img.numpy())
        else:
            # PIL图片
            ax.imshow(img)
        ax.axes.get_xaxis().set_visible(False)
        ax.axes.get_yaxis().set_visible(False)
        if titles:
            ax.set_title(titles[i])
    return axes

加载图像

mnist_train数据加载器中获取下一批数据,其中batch_size=18意味着每次加载18个样本。X是图像数据,y是对应的标签。
在调用show_images函数之前,先对X进行了重塑(reshape),因为从数据加载器获取的X可能是四维的(批量大小、通道数、高度、宽度),而Fashion-MNIST是灰度图像,通道数为1,所以这里将其重塑为三维(18个28x28的图像)。然后,指定了2行9列的布局来显示这些图像,并使用get_fashion_mnist_labels函数将y中的数字标签转换为文本标签作为图像的标题。
最后,调用d2l.plt.show()(“动手学深度学习”的库封装matplotlib的功能)来显示图像和标题。

X, y = next(iter(data.DataLoader(mnist_train, batch_size=18)))
show_images(X.reshape(18, 28, 28), 2, 9, titles=get_fashion_mnist_labels(y));
d2l.plt.show()

完整代码及运行结果

import torch
import torchvision
from torch.utils import data
from torchvision import transforms
from d2l import torch as d2l


# 通过ToTensor实例将图像数据从PIL类型变换成32位浮点数格式,
# 并除以255使得所有像素的数值均在0~1之间
trans = transforms.ToTensor()
mnist_train = torchvision.datasets.FashionMNIST(
    root="../data", train=True, transform=trans, download=True)
mnist_test = torchvision.datasets.FashionMNIST(
    root="../data", train=False, transform=trans, download=True)

def get_fashion_mnist_labels(labels):  #@save
    """返回Fashion-MNIST数据集的文本标签"""
    text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat',
                   'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
    return [text_labels[int(i)] for i in labels]

def show_images(imgs, num_rows, num_cols, titles=None, scale=1.5):  #@save
    """绘制图像列表"""
    figsize = (num_cols * scale, num_rows * scale)
    _, axes = d2l.plt.subplots(num_rows, num_cols, figsize=figsize)
    axes = axes.flatten()
    for i, (ax, img) in enumerate(zip(axes, imgs)):
        if torch.is_tensor(img):
            # 图片张量
            ax.imshow(img.numpy())
        else:
            # PIL图片
            ax.imshow(img)
        ax.axes.get_xaxis().set_visible(False)
        ax.axes.get_yaxis().set_visible(False)
        if titles:
            ax.set_title(titles[i])
    return axes

X, y = next(iter(data.DataLoader(mnist_train, batch_size=18)))
show_images(X.reshape(18, 28, 28), 2, 9, titles=get_fashion_mnist_labels(y));
d2l.plt.show()

在这里插入图片描述

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值