PyTorch 中的 torch.utils.data
解析
在 PyTorch 中,提供了一个处理数据集的工具包 torch.utils.data
。这里来简单介绍这个包的结构。以下内容翻译和整理自 PyTorch 官方文档。
概述
PyTorch 数据集处理包 torch.utils.data
的核心是 DataLoader
类。该类的构造函数签名为
DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
batch_sampler=None, num_workers=0, collate_fn=None,
pin_memory=False, drop_last=False, timeout=0,
worker_init_fn=None, *, prefetch_factor=2,
persistent_workers=False)
它构造一个 可迭代对象 loader
,代表经过 “加工” 后的数据集。所谓的 “加工” 过程,是由构造函数参数表指定的,它包括:
- 设置数据的加载顺序(通过修改
shuffle
或sampler
参数) - 对数据进行 batching 处理(通过修改
batch_size
,batch_sampler
,collate_fn
及drop_last
参数) - 实现 multi-process loading,memory pinning 等(此处不涉及)
一旦构造了 DataLoader
对象 loader
,就可以用
for data in loader:
# data 是数据集中的一组数据,且已转换成 Tensor
来加载数据。缺省情况下,PyTorch 会对数据进行 auto-batching,此时 data
对应一个 batch 的数据。
可以将 loader
理解成一个 生成器,其定义按情况可分为:(出现一些概念之后都会解释)
## 启用 auto-batching
# 对 map-style 数据集
for indices in batch_sampler:
yield collate_fn([dataset[i] for i in indices])
# 对 iterable-style 数据集
dataset_iter = iter(dataset)
for indices in batch_sampler:
yield collate_fn([next(dataset_iter) for _ in indices])
## 不启用 auto-batching (设置 batch_size=None 和 batch_sampler=None)
# 对 map-style 数据集
for index in sampler:
yield collate_fn(dataset[index])
# 对 iterable-style 数据集
for data in iter(dataset):
yield collate_fn(data)
数据集
DataLoader
构造函数中的必需参数 dataset
代表一个数据集。数据集主要分为两种:
- map-style 数据集:它是
torch.utils.data.Dataset
的子类,重载了__getitem__
和__len__
运算符,可以随机访问数据集中的数据 - iterable-style 数据集:它是
torch.utils.data.IterableDataset
的子类,是可迭代对象
数据加载顺序
手动定义 sampler
可以通过指定 sampler
参数来手动设置加载顺序。一个 sampler
是可迭代对象,其迭代的每一个值表示下一个待加载数据的 key/index。它应当实例化泛型类 torch.utils.data.Sampler[int]
的一个子类,并且重载 __iter__
和 __len__
函数,具体地讲:
- 构造函数
__init__(self, data_source, *args)
必须提供一个重载了__len__
的数据集data_source
作为参数 __iter__
返回一个整型迭代器,其每迭代一次的返回值为下一个待加载数据的 key/index__len__
返回要加载的数据总数
但是要注意,只有 map-style 数据集才可定义 sampler,因为 iterable-style 不一定支持随机访问。
使用内置 sampler
在模块 torch.utils.data.sampler
中定义了一些内置的 sampler,通常来说已经够用了。在缺省 sampler
参数的情况下,如果指定参数 shuffle=False
将使用 SequentialSampler
,即按顺序加载整个数据集;如果指定 shuffle=True
则使用 RandomSampler
,即随机打乱数据后加载整个数据集。但是注意,不允许同时指定 sampler
参数和 shuffle
参数。
另外一些 sampler 可以参见模块源代码。
数据的 batching
在训练神经网络的时候经常需要将数据分成 mini-batch。PyTorch 本身提供了 auto-batching 的功能,也可以通过修改参数 batch_size
,batch_sampler
,drop_last
及 collate_fn
进行自定义 batching。
使用 batch sampler
在定义了 batching 后,PyTorch 会一次性输入多个(数量为 batch_size
)数据。这时候需使用 batch sampler 来取代普通的 sampler。
通过指定 batch_sampler
参数,可以手动实现想要的 batch sampler。一个 batch_sampler
是 torch.utils.data.BatchSampler
的实例。在 PyTorch 源代码中,该类继承了 Sampler[List[int]]
,并且封装了一个 sampler
。
- 构造函数签名为
__init__(self, sampler, batch_size: int, drop_last: bool)
,其中sampler
是一个可迭代对象,代表被封装的 samplerbatch_size
代表每个 batch 的数据量drop_last
表示要不要把最后一个不足batch_size
的 batch 丢掉
__iter__
返回一个迭代器,它每迭代一次,返回一个List[int]
,表示下一个 batch 的 key/index 列表__len__
返回 batch 总数
请注意,
- 如果自定义了
batch_sampler
,那么不能再指定sampler
,shuffle
,batch_size
和drop_last
参数 - 如果没有指定
batch_sampler
参数,但batch_size
不为None
,则DataLoader
构造函数自动使用自定义的sampler
或由shuffle
指定的内置 sampler,以及batch_size
和drop_last
参数封装 batch sampler - 如果既没有指定
batch_sampler
参数,又设置batch_size
为None
,则禁用 auto-batching,每加载一次输出的是单个数据。
修改 collate_fn
参数 collate_fn
指定如何对每一 batch 的数据做预处理。在模块 torch.utils.data._utils
中,定义了两个默认的 collate_fn
:
default_convert
:如果禁用 auto-batching,则用该函数将每个数据预处理为torch.Tensor
default_collate
:如果启用 auto-batching,则用该函数将每个 batch 预处理为torch.Tensor