本文章将首先介绍 Dataset、DataLoader、Sampler、BatchSampler、collate_fn 的概念,然后从源码角度解读 DataLoader 与这些模块的关系。
如果你熟悉基本概念,可以直接从最后章节开始阅读。
1 Dataset
Pytorch 支持两种类型的数据集 Map-style Dataset 和 Iterable-style Dataset,提供表示数据集的抽象类,任何自定义的 Dataset 都需要继承该类并覆写相关方法
1.1 Map-style Dataset
- 需要继承
torch.utils.data.Dataset
- 需要覆写两个方法
__getitem__(self, index)
__len__(self)
- 本质上构建了 index 到 data 的映射,
dataset[idx]
返回数据集中第 idx 个 item。值得注意的是,idx可以不是 int 类型 len(dataset)
返回数据集的大小
1.2 Iterable-style Dataset
- 需要继承
torch.utils.data.IterableDataset
- 需要覆写一个方法
__iter__(self)
- 本质上是一个可迭代对象,通过
next(dataset)
调用__iter__(self)
方法返回数据集的下一个 item
2 Sampler
Sampler 本质上是迭代器,用于产生数据集的索引值序列。
2.1 内置的 Sampler
PyTorch 提供了多种内置的 Sampler:
-
- SequentialSampler
- RandomSampler
- WeightedSampler
- SubsetRandomSampler
2.2 自定义 Sampler
- 需要继承
torch.utils.data.Sampler
- 需要覆写
__iter__(self)
方法,返回值必须可迭代
3 BatchSampler
BatchSampler 将 Sampler 采样得到的索引值进行合并,当数量等于一个 batch 大小后就将这一批的索引值返回。
4 collect_fn
Pytorch 内置的 default_collate()
会将 NumPy arrays 转换为 PyTorch tensors
5 Dataloader
5.1 参数解析
class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=None, sampler=None, batch_sampler=None, num_workers=0, collate_fn=None, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None, multiprocessing_context=None, generator=None, *, prefetch_factor=None, persistent_workers=False, pin_memory_device=' ' )
- 如果
sampler
和batch_sampler
都为 None,batch_sampler
将使用内置的BatchSampler
,sampler
有两种情况:- 若
shuffle=True
,则sampler=RandomSampler(dataset)
- 若
shuffle=False
,则sampler=SequentialSampler(dataset)
- 若
- 如果使用自定义
batch_sampler
- 不能指定
sampler
,batch_size
,shuffle
,drop_last
sampler
作为batch_sampler
的参数传入
- 不能指定
- 如果使用自定义
sampler
,则不能指定shuffle
,这是因为shuffle
仅用于缺省sampler
时选择 pytorch 内置的sampler
- 如果
dataset
使用 IterableDataset,则不能指定sampler
,batch_sampler
sampler
将使用_InfiniteConstantSampler()
,这是一个 dummy sampler,调用 __iter__(self) 方法永远返回 None- 可以指定
batch_size
,drop_last
,batch_sampler
将使用内置的 BatchSampler,
5.2 Automatic batching 与 Disable automatic batching
DataLoader 有两种模式 Automatic batching 和 Disable automatic batching
- 当
batch_size
和drop_last
均为 None 的时候,使用 Disable automatic batching 模式 - 否则使用 Automatic batching 模式
两种模式都有不同的逻辑来处理 Map-style Dataset 和 Iterable-style Dataset
Pytorch 在torch.utils.data._utils.fetch
中构建了两种模式下处理 Map-style Dataset 和 Iterable-style Dataset 的逻辑:
class _BaseDatasetFetcher:
def __init__(self, dataset, auto_collation, collate_fn, drop_last):
self.dataset = dataset
self.auto_collation = auto_collation
self.collate_fn = collate_fn
self.drop_last = drop_last
def fetch(self, possibly_batched_index):
raise NotImplementedError()
class _IterableDatasetFetcher(_BaseDatasetFetcher):
def __init__(self, dataset, auto_collation, collate_fn, drop_last):
super().__init__(dataset, auto_collation, collate_fn, drop_last)
self.dataset_iter = iter(dataset)
self.ended = False
def fetch(self, possibly_batched_index):
if self.ended:
raise StopIteration
if self.auto_collation:
data = []
for _ in possibly_batched_index:
try:
data.append(next(self.dataset_iter))
except StopIteration:
self.ended = True
break
if len(data) == 0 or (
self.drop_last and len(data) < len(possibly_batched_index)
):
raise StopIteration
else:
data = next(self.dataset_iter)
return self.collate_fn(data)
class _MapDatasetFetcher(_BaseDatasetFetcher):
def fetch(self, possibly_batched_index):
if self.auto_collation:
if hasattr(self.dataset, "__getitems__") and self.dataset.__getitems__:
data = self.dataset.__getitems__(possibly_batched_index)
else:
data = [self.dataset[idx] for idx in possibly_batched_index]
else:
data = self.dataset[possibly_batched_index]
return self.collate_fn(data)
Automatic batching 的处理逻辑可以简化为:
sampler
采样dataset
batch_sampler
依次将sampler
采样得到的 indices 进行合并,当数量等于batch_size
时将这个 batch 的 indices 返回。drop_last
决定是否丢弃最后不足一个 batch 的部分- DataLoader 依次按照
batch_sampler
提供的 batch indices 将数据从dataset
中读出,传给collate_fn
进行整理,返回 Tensor
# map-style dataset
for indices in batch_sampler:
yield collate_fn([dataset[i] for i in indices])
# iterable-style dataset
dataset_iter = iter(dataset)
for indices in batch_sampler:
yield collate_fn([next(dataset_iter) for _ in indices])
Disable automatic batching 的处理逻辑可以简化为:
sampler
采样dataset
- DataLoader 依次按照
sampler
提供的 indices 将数据从dataset
中读出,传给collate_fn
进行整理,返回 Tensor
# map-style dataset
for index in sampler:
yield collate_fn(dataset[index])
# iterable-style dataset
for data in iter(dataset):
yield collate_fn(data)
Pytorch 源码解读:Dataset、DataLoader、Sampler、BatchSampler、collate_fn - 知乎