Pytorch 源码解读:Dataset、DataLoader、Sampler、BatchSampler、collate_fn

本文章将首先介绍 Dataset、DataLoader、Sampler、BatchSampler、collate_fn 的概念,然后从源码角度解读 DataLoader 与这些模块的关系。

如果你熟悉基本概念,可以直接从最后章节开始阅读。

1 Dataset

Pytorch 支持两种类型的数据集 Map-style Dataset 和 Iterable-style Dataset,提供表示数据集的抽象类,任何自定义的 Dataset 都需要继承该类并覆写相关方法

1.1 Map-style Dataset

  • 需要继承torch.utils.data.Dataset
  • 需要覆写两个方法
    • __getitem__(self, index)
    • __len__(self)
  • 本质上构建了 index 到 data 的映射,dataset[idx] 返回数据集中第 idx 个 item。值得注意的是,idx可以不是 int 类型
  • len(dataset) 返回数据集的大小

1.2 Iterable-style Dataset

  • 需要继承torch.utils.data.IterableDataset
  • 需要覆写一个方法 __iter__(self)
  • 本质上是一个可迭代对象,通过 next(dataset) 调用 __iter__(self) 方法返回数据集的下一个 item

2 Sampler

Sampler 本质上是迭代器,用于产生数据集的索引值序列。

2.1 内置的 Sampler

PyTorch 提供了多种内置的 Sampler:

    • SequentialSampler
    • RandomSampler
    • WeightedSampler
    • SubsetRandomSampler

2.2 自定义 Sampler

  • 需要继承 torch.utils.data.Sampler
  • 需要覆写 __iter__(self) 方法,返回值必须可迭代

3 BatchSampler

BatchSampler 将 Sampler 采样得到的索引值进行合并,当数量等于一个 batch 大小后就将这一批的索引值返回。

4 collect_fn

Pytorch 内置的 default_collate() 会将 NumPy arrays 转换为 PyTorch tensors

5 Dataloader

5.1 参数解析

class torch.utils.data.DataLoader(datasetbatch_size=1shuffle=Nonesampler=Nonebatch_sampler=Nonenum_workers=0collate_fn=Nonepin_memory=Falsedrop_last=Falsetimeout=0worker_init_fn=Nonemultiprocessing_context=Nonegenerator=None*prefetch_factor=Nonepersistent_workers=Falsepin_memory_device=' ' )

  • 如果sampler 和 batch_sampler 都为 None,batch_sampler 将使用内置的 BatchSamplersampler 有两种情况:
    • 若 shuffle=True,则 sampler=RandomSampler(dataset)
    • 若 shuffle=False,则 sampler=SequentialSampler(dataset)
  • 如果使用自定义 batch_sampler
    • 不能指定 sampler , batch_sizeshuffledrop_last
    • sampler 作为 batch_sampler的参数传入
  • 如果使用自定义 sampler,则不能指定 shuffle,这是因为 shuffle 仅用于缺省 sampler 时选择 pytorch 内置的 sampler
  • 如果 dataset 使用 IterableDataset,则不能指定 samplerbatch_sampler
    • sampler 将使用 _InfiniteConstantSampler(),这是一个 dummy sampler,调用 __iter__(self) 方法永远返回 None
    • 可以指定 batch_sizedrop_lastbatch_sampler 将使用内置的 BatchSampler,

5.2 Automatic batching 与 Disable automatic batching

DataLoader 有两种模式 Automatic batching 和 Disable automatic batching

  • 当 batch_size 和 drop_last 均为 None 的时候,使用 Disable automatic batching 模式
  • 否则使用 Automatic batching 模式

两种模式都有不同的逻辑来处理 Map-style Dataset 和 Iterable-style Dataset

Pytorch 在torch.utils.data._utils.fetch 中构建了两种模式下处理 Map-style Dataset 和 Iterable-style Dataset 的逻辑:

class _BaseDatasetFetcher:
    def __init__(self, dataset, auto_collation, collate_fn, drop_last):
        self.dataset = dataset
        self.auto_collation = auto_collation
        self.collate_fn = collate_fn
        self.drop_last = drop_last

    def fetch(self, possibly_batched_index):
        raise NotImplementedError()


class _IterableDatasetFetcher(_BaseDatasetFetcher):
    def __init__(self, dataset, auto_collation, collate_fn, drop_last):
        super().__init__(dataset, auto_collation, collate_fn, drop_last)
        self.dataset_iter = iter(dataset)
        self.ended = False

    def fetch(self, possibly_batched_index):
        if self.ended:
            raise StopIteration

        if self.auto_collation:
            data = []
            for _ in possibly_batched_index:
                try:
                    data.append(next(self.dataset_iter))
                except StopIteration:
                    self.ended = True
                    break
            if len(data) == 0 or (
                self.drop_last and len(data) < len(possibly_batched_index)
            ):
                raise StopIteration
        else:
            data = next(self.dataset_iter)
        return self.collate_fn(data)


class _MapDatasetFetcher(_BaseDatasetFetcher):
    def fetch(self, possibly_batched_index):
        if self.auto_collation:
            if hasattr(self.dataset, "__getitems__") and self.dataset.__getitems__:
                data = self.dataset.__getitems__(possibly_batched_index)
            else:
                data = [self.dataset[idx] for idx in possibly_batched_index]
        else:
            data = self.dataset[possibly_batched_index]
        return self.collate_fn(data)

Automatic batching 的处理逻辑可以简化为:

  1. sampler 采样 dataset
  2. batch_sampler 依次将 sampler 采样得到的 indices 进行合并,当数量等于 batch_size 时将这个 batch 的 indices 返回。drop_last 决定是否丢弃最后不足一个 batch 的部分
  3. DataLoader 依次按照 batch_sampler 提供的 batch indices 将数据从 dataset 中读出,传给 collate_fn 进行整理,返回 Tensor
# map-style dataset
for indices in batch_sampler:
    yield collate_fn([dataset[i] for i in indices])

# iterable-style dataset
dataset_iter = iter(dataset)
for indices in batch_sampler:
    yield collate_fn([next(dataset_iter) for _ in indices])

Disable automatic batching 的处理逻辑可以简化为:

  1. sampler 采样 dataset
  2. DataLoader 依次按照 sampler 提供的 indices 将数据从 dataset 中读出,传给 collate_fn 进行整理,返回 Tensor
# map-style dataset
for index in sampler:
    yield collate_fn(dataset[index])

# iterable-style dataset
for data in iter(dataset):
    yield collate_fn(data)




Pytorch 源码解读:Dataset、DataLoader、Sampler、BatchSampler、collate_fn - 知乎

  • 16
    点赞
  • 21
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值