【机器学习】024_Softmax模型Part.2_图片分类数据集

准备所需导入的库

代码:

import torch
import torchvision
from torch.utils import data
from torchvision import transforms
from d2l import torch as d2l
# 将 d2l 库生成的图形的显示格式设置为 SVG(可缩放矢量图形)格式
d2l.use_svg_display()

一、读取多类分类图片数据集

Fashion-MNIST数据集

通过内置函数将Fashion-MNIST数据集下载并读取到内存中

## 使用框架内置函数将Fashion-MNIST数据集下载并读取到内存中
# 通过ToTensor实例将图像数据从PIL类型变换为32位浮点数
# 除以255使得所有像素的值在0到1之间
trans = transforms.ToTensor()
# 定义训练数据集和测试数据集,从网上下载
mnist_train = torchvision.datasets.FashionMNIST(
    root="../data", train=True, transform=trans, download=True
)
mnist_test = torchvision.datasets.FashionMNIST(
    root="../data", train=False, transform=trans, download=True
)
# 显示数据集的数据量(图片张数)
print(len(mnist_train))
print(len(mnist_test))

· Fashion-MNIST由10个类别的图像组成, 每个类别由训练数据集(train dataset)中的6000张图像 和测试数据集(test dataset)中的1000张图像组成。 因此,训练集和测试集分别包含60000和10000张图像。 测试数据集不会用于训练,只用于评估模型性能。

二、可视化数据样本

将图片数据集下载下来后,要对数据集进行可视化,将图片显示出来。

Fashion-MNIST中包含的10个类别,分别为t-shirt(T恤)、trouser(裤子)、pullover(套衫)、dress(连衣裙)、coat(外套)、sandal(凉鞋)、shirt(衬衫)、sneaker(运动鞋)、bag(包)和ankle boot(短靴)。

1. 定义一个函数,用于在数字标签索引及其文本名称之间进行转换。

def get_fashion_mnist_labels(labels):  #@save
    """返回Fashion-MNIST数据集的文本标签"""
    text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat',
                   'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
    return [text_labels[int(i)] for i in labels]

2. 创建一个可视化函数,将PIL图片连同标签显示出来

def show_images(imgs, num_rows, num_cols, titles=None, scale=1.5):  #@save
    """绘制图像列表"""
    figsize = (num_cols * scale, num_rows * scale)
    _, axes = d2l.plt.subplots(num_rows, num_cols, figsize=figsize)
    axes = axes.flatten()
    for i, (ax, img) in enumerate(zip(axes, imgs)):
        if torch.is_tensor(img):
            # 图片张量
            ax.imshow(img.numpy())
        else:
            # PIL图片
            ax.imshow(img)
        ax.axes.get_xaxis().set_visible(False)
        ax.axes.get_yaxis().set_visible(False)
        if titles:
            ax.set_title(titles[i])
    return axes

X, y = next(iter(data.DataLoader(mnist_train, batch_size=18)))
show_images(X.reshape(18, 28, 28), 2, 9, titles=get_fashion_mnist_labels(y));
d2l.plt.show()

数据集中前几个样本的图像和对应标签:

三、读取小批量数据

为了使我们在读取训练集和测试集时更容易,我们使用内置的数据迭代器,而不是从零开始创建。 回顾一下,在每次迭代中,数据加载器每次都会读取一小批量数据,大小为batch_size。 通过内置数据迭代器,我们可以随机打乱所有样本,从而无偏见地读取小批量。

使用4进程读取数据,也可以使用8进程等(根据CPU选择)。

代码:

batch_size = 256

def get_dataloader_workers():  #@save
    """使用4个进程来读取数据"""
    return 4

train_iter = data.DataLoader(mnist_train, batch_size, shuffle=True,
                             num_workers=get_dataloader_workers())

        · DataLoader表示一个数据加载器对象,用于批量加载数据。它接收训练数据集、批次大小、是否打乱数据以及读取数据的进程数量等参数。

        · 在这里,mnist_train 是要加载的训练数据集,batch_size 是每个批次中的样本数量,shuffle=True 表示在每个 epoch(训练轮次)开始时打乱数据集,然后使用4个进程来读取数据。

        · 创建数据加载器对象,可以方便地迭代访问训练数据集中的批次数据,以供模型训练使用。

查看读取训练数据所需的时间:

timer = d2l.Timer()
for X, y in train_iter:
    continue
f'{timer.stop():.2f} sec'

        · 使用 for 循环遍历 train_iter,每次迭代获取一个批次的训练数据,其中 X 是输入数据,y 是对应的标签数据。在这个例子中,由于循环体内没有实际的操作,使用 continue 关键字跳过当前迭代,即不执行任何操作。timer则用于计时。

        · 接下来,使用 timer.stop() 方法停止计时,并使用 f-string 格式化字符串的方式将计时结果转换为保留两位小数的字符串。

四、整合图片数据集组件

定义load_data_fashion_mnist函数,用于获取和读取Fashion-MNIST数据集。 这个函数返回训练集和验证集的数据迭代器。 此外,这个函数还接受一个可选参数resize,用来将图像大小调整为另一种形状。

# 整合所有组件,创建图像分类数据集
def load_data_fashion_mnist(batch_size, resize=None):  #@save
    """下载Fashion-MNIST数据集,然后将其加载到内存中"""
    trans = [transforms.ToTensor()]
    # 如果resize参数存在,则将 transforms.Resize(resize) 实例插入到 trans 列表的第一个位置
    # 这样,可以确保在将图像数据转换为张量之前,先对图像进行调整大小的操作。
    if resize:
        trans.insert(0, transforms.Resize(resize))
    # 使用 transforms.Compose 函数将 trans 列表中的实例组合成一个变换序列——方便一次性对图像数据进行多个变换操作
    trans = transforms.Compose(trans)
    # 下载数据集
    mnist_train = torchvision.datasets.FashionMNIST(
        root="../data", train=True, transform=trans, download=True)
    mnist_test = torchvision.datasets.FashionMNIST(
        root="../data", train=False, transform=trans, download=True)
    # 返回两个数据加载器对象,第一个是训练数据,第二个是测试数据
    return (data.DataLoader(mnist_train, batch_size, shuffle=True,
                            num_workers=get_dataloader_workers()),
            data.DataLoader(mnist_test, batch_size, shuffle=False,
                            num_workers=get_dataloader_workers()))
# 获取数据加载器对象,并指定图像大小为64x64
train_iter, test_iter = load_data_fashion_mnist(32, resize=64)
for X, y in train_iter:
    print(X.shape, X.dtype, y.shape, y.dtype)
    break

通过该函数,我们可以很方便地得到图像分类数据集,获取训练数据和测试数据的数据加载器,以供后续模型训练使用。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值