PyTorch 中的 torch.utils.data 解析

PyTorch 中的 torch.utils.data 解析

在 PyTorch 中,提供了一个处理数据集的工具包 torch.utils.data。这里来简单介绍这个包的结构。以下内容翻译和整理自 PyTorch 官方文档

概述

PyTorch 数据集处理包 torch.utils.data 的核心是 DataLoader 类。该类的构造函数签名为

DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
           batch_sampler=None, num_workers=0, collate_fn=None,
           pin_memory=False, drop_last=False, timeout=0,
           worker_init_fn=None, *, prefetch_factor=2,
           persistent_workers=False)

它构造一个 可迭代对象 loader,代表经过 “加工” 后的数据集。所谓的 “加工” 过程,是由构造函数参数表指定的,它包括:

  • 设置数据的加载顺序(通过修改 shufflesampler 参数)
  • 对数据进行 batching 处理(通过修改 batch_sizebatch_samplercollate_fndrop_last 参数)
  • 实现 multi-process loading,memory pinning 等(此处不涉及)

一旦构造了 DataLoader 对象 loader,就可以用

for data in loader:
    # data 是数据集中的一组数据,且已转换成 Tensor

来加载数据。缺省情况下,PyTorch 会对数据进行 auto-batching,此时 data 对应一个 batch 的数据。

可以将 loader 理解成一个 生成器,其定义按情况可分为:(出现一些概念之后都会解释)

## 启用 auto-batching

# 对 map-style 数据集
for indices in batch_sampler:
    yield collate_fn([dataset[i] for i in indices])
# 对 iterable-style 数据集
dataset_iter = iter(dataset)
for indices in batch_sampler:
    yield collate_fn([next(dataset_iter) for _ in indices])

## 不启用 auto-batching (设置 batch_size=None 和 batch_sampler=None)

# 对 map-style 数据集
for index in sampler:
    yield collate_fn(dataset[index])
# 对 iterable-style 数据集
for data in iter(dataset):
    yield collate_fn(data)

数据集

DataLoader 构造函数中的必需参数 dataset 代表一个数据集。数据集主要分为两种:

  • map-style 数据集:它是 torch.utils.data.Dataset 的子类,重载了 __getitem____len__ 运算符,可以随机访问数据集中的数据
  • iterable-style 数据集:它是 torch.utils.data.IterableDataset 的子类,是可迭代对象

数据加载顺序

手动定义 sampler

可以通过指定 sampler 参数来手动设置加载顺序。一个 sampler 是可迭代对象,其迭代的每一个值表示下一个待加载数据的 key/index。它应当实例化泛型类 torch.utils.data.Sampler[int] 的一个子类,并且重载 __iter____len__ 函数,具体地讲:

  • 构造函数 __init__(self, data_source, *args) 必须提供一个重载了 __len__ 的数据集 data_source 作为参数
  • __iter__ 返回一个整型迭代器,其每迭代一次的返回值为下一个待加载数据的 key/index
  • __len__ 返回要加载的数据总数

但是要注意,只有 map-style 数据集才可定义 sampler,因为 iterable-style 不一定支持随机访问。

使用内置 sampler

在模块 torch.utils.data.sampler 中定义了一些内置的 sampler,通常来说已经够用了。在缺省 sampler 参数的情况下,如果指定参数 shuffle=False 将使用 SequentialSampler,即按顺序加载整个数据集;如果指定 shuffle=True 则使用 RandomSampler,即随机打乱数据后加载整个数据集。但是注意,不允许同时指定 sampler 参数和 shuffle 参数。

另外一些 sampler 可以参见模块源代码。

数据的 batching

在训练神经网络的时候经常需要将数据分成 mini-batch。PyTorch 本身提供了 auto-batching 的功能,也可以通过修改参数 batch_sizebatch_samplerdrop_lastcollate_fn 进行自定义 batching。

使用 batch sampler

在定义了 batching 后,PyTorch 会一次性输入多个(数量为 batch_size)数据。这时候需使用 batch sampler 来取代普通的 sampler。

通过指定 batch_sampler 参数,可以手动实现想要的 batch sampler。一个 batch_samplertorch.utils.data.BatchSampler 的实例。在 PyTorch 源代码中,该类继承了 Sampler[List[int]],并且封装了一个 sampler

  • 构造函数签名为 __init__(self, sampler, batch_size: int, drop_last: bool),其中
    • sampler 是一个可迭代对象,代表被封装的 sampler
    • batch_size 代表每个 batch 的数据量
    • drop_last 表示要不要把最后一个不足 batch_size 的 batch 丢掉
  • __iter__ 返回一个迭代器,它每迭代一次,返回一个 List[int],表示下一个 batch 的 key/index 列表
  • __len__ 返回 batch 总数

请注意,

  • 如果自定义了 batch_sampler,那么不能再指定 samplershufflebatch_sizedrop_last 参数
  • 如果没有指定 batch_sampler 参数,但 batch_size 不为 None,则 DataLoader 构造函数自动使用自定义的 sampler 或由 shuffle 指定的内置 sampler,以及 batch_sizedrop_last 参数封装 batch sampler
  • 如果既没有指定 batch_sampler 参数,又设置 batch_sizeNone,则禁用 auto-batching,每加载一次输出的是单个数据。

修改 collate_fn

参数 collate_fn 指定如何对每一 batch 的数据做预处理。在模块 torch.utils.data._utils 中,定义了两个默认的 collate_fn

  • default_convert:如果禁用 auto-batching,则用该函数将每个数据预处理为 torch.Tensor
  • default_collate:如果启用 auto-batching,则用该函数将每个 batch 预处理为 torch.Tensor
  • 5
    点赞
  • 16
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值