【“Transformers快速入门”学习笔记5】pytorch中的Dataset和DataLoaders

Pytorch 提供了 DataLoaderDataset 类(或 IterableDataset)专门用于处理数据,它们既可以加载 Pytorch 预置的数据集,也可以加载自定义数据。其中数据集类 Dataset(或 IterableDataset)负责存储样本以及它们对应的标签;数据加载类 DataLoader 负责迭代地访问数据集中的样本

Dataset

所有的数据集都必须继承自Dataset或IterableDataset

pytorch支持两种数据集

  • 映射型(Map-Style)数据集
    继承自Dataset类,表示一个从索引到样本的映射(索引可以不是整数),这样我们就可以方便地通过 dataset[idx] 来访问指定索引的样本。这也是目前最常见的数据集类型。映射型数据集必须实现 getitem() 函数,其负责根据指定的 key 返回对应的样本。一般还会实现 len() 用于返回数据集的大小。
  • 迭代型(Iterable-Style)数据集
    继承自 IterableDataset,表示可迭代的数据集,它可以通过 iter(dataset) 以数据流 (steam) 的形式访问,适用于访问超大数据集或者远程服务器产生的数据。 迭代型数据集必须实现 iter() 函数,用于返回一个样本迭代器 (iterator)。

下面分析具体代码

自定义映射行数据集(图像分类数据集)

import os
import pandas as pd
from torchvision.io import read_image
from torch.utils.data import Dataset


class CustomImageDataset(Dataset):
    def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
        self.img_labels = pd.read_csv(annotations_file)  # 从 CSV 文件中读取图像路径和标签
        self.img_dir = img_dir  # 图像文件所在的目录
        self.transform = transform  # 对图像进行的转换(如预处理)
        self.target_transform = target_transform  # 对标签进行的转换

    def __len__(self):
        return len(self.img_labels)  # 返回数据集中图像的数量

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])  # 获取图像的文件路径
        image = read_image(img_path)  # 使用 torchvision.io.read_image 读取图像文件并转换为张量
        label = self.img_labels.iloc[idx, 1]  # 获取图像对应的标签

        # 如果定义了图像转换函数,则对图像进行转换
        if self.transform:
            image = self.transform(image)

        # 如果定义了标签转换函数,则对标签进行转换
        if self.target_transform:
            label = self.target_transform(label)

        return image, label  # 返回图像和标签的元组
__init__()初始化数据集参数
__len__()返回数据集中样本的个数
__getitem__()映射行数据集的核心,根据给定的索引idx返回样本。

Dataloaders

训练模型时,需要先将数据集切分成若干batches,按batch将样本送入模型,循环这一过程,每完成一个周期称为一个epoch。

训练模型时,会在每次epoch开始前随机打乱样本以降低过拟合
pytorch提供DataLoader 类专门负责处理这些操作,除了基本的 dataset(数据集)和 batch_size (batch 大小)参数以外,还有以下常用参数:

  • shuffle:是否打乱数据集;
  • sampler:采样器,也就是一个索引上的迭代器;
  • collate_fn:批处理函数,用于对采样出的一个 batch 中的样本进行处理(例如前面提过的 Padding 操作)。

数据加载顺序和sampler类
对于迭代型数据集来说,数据的加载顺序直接由用户控制,用户可以精确地控制每一个 batch 中返回的样本,因此不需要使用 Sampler 类。
对于映射型数据集来说,由于索引可以不是整数,因此我们可以通过 Sampler 对象来设置加载时的索引序列,即设置一个索引上的迭代器。如果设置了 shuffle 参数,DataLoader 就会自动创建一个顺序或乱序的 sampler,我们也可以通过 sampler 参数传入一个自定义的 Sampler 对象。

from torch.utils.data import DataLoader
from torch.utils.data import SequentialSampler, RandomSampler
from torchvision import datasets
from torchvision.transforms import ToTensor

# 定义训练集和测试集的数据集对象,并进行相应的预处理
training_data = datasets.FashionMNIST(
    root="data",             # 数据集存储的根目录
    train=True,              # 加载训练集
    download=True,           # 如果数据集不存在,则下载数据集
    transform=ToTensor()     # 将图像转换为 PyTorch 张量
)

test_data = datasets.FashionMNIST(
    root="data",             # 数据集存储的根目录
    train=False,             # 加载测试集
    download=True,           # 如果数据集不存在,则下载数据集
    transform=ToTensor()     # 将图像转换为 PyTorch 张量
)

# 创建训练集和测试集的采样器对象
train_sampler = RandomSampler(training_data)    # 随机采样训练集
test_sampler = SequentialSampler(test_data)     # 顺序采样测试集

# 创建训练集和测试集的数据加载器,每个批次包含64个样本
train_dataloader = DataLoader(training_data, batch_size=64, sampler=train_sampler)
test_dataloader = DataLoader(test_data, batch_size=64, sampler=test_sampler)

# 获取并展示一个批次的训练集和测试集数据的形状
train_features, train_labels = next(iter(train_dataloader))
print(f"Feature batch shape: {train_features.size()}")  # 打印训练集特征的形状
print(f"Labels batch shape: {train_labels.size()}")     # 打印训练集标签的形状

test_features, test_labels = next(iter(test_dataloader))
print(f"Feature batch shape: {test_features.size()}")   # 打印测试集特征的形状
print(f"Labels batch shape: {test_labels.size()}")      # 打印测试集标签的形状

输出:

Feature batch shape: torch.Size([64, 1, 28, 28])
Labels batch shape: torch.Size([64])
torch.Size([28, 28])
Label: 5

下面解释一下上面出现的一些函数和方法:

  • iter():创建迭代器
  • next():获取下一个批次的数据
  • RandomSampler():随机采样器
  • SequentialSampler():顺序采样器
  • ToTensor():转换成pytorch张量

批处理函数collate_fn
对每一个采样出的batch中的样本进行处理,模型collate_fn会进行以下操作:

  • 添加新的batch维
  • 自动将python数值和numpy数组转成pytorch张量
  • 保留原始的数据结构,例如输入是字典的话,它会输出一个包含同样键 (key) 的字典,但是将值 (value) 替换为 batched 张量(如何可以转换的话)。
  • 17
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值