PyTorch 源码解读之 torch.utils.data:解析数据处理全流程(非常好,一篇足够)


本篇博文主要用来记录参考链接中的所学重要知识,梳理清楚。

1 Dataset

Dataset 负责对 raw data source 封装,将其封装成 Python 可识别的数据结构,其必须提供提取数据个体的接口。

  • Map-style sataset
    torch.utils.data.Dataset
    它是一种通过实现 getitem() 和 len() 来获取数据的 Dataset,它表示从(可能是非整数)索引/关键字到数据样本的映射。访问时,这样的数据集用 dataset[idx] 访问 idx 对应的数据。
  • Iterable-style dataset
  • 其他Dataset
    torch.utils.data.TensorDataset: 用于获取封装成 tensor 的数据集,每一个样本都通过索引张量来获得。

2 Sampler

torch.utils.data.Sampler 负责提供一种遍历数据集所有元素索引的方式。
特别地,len() 方法不是必要的,但是当 DataLoader 需要计算 len() 的时候必须定义,这点在其源码中也有注释加以体现。

同样,PyTorch 也在此基础上提供了其他类型的 Sampler 子类

torch.utils.data.SequentialSampler : 顺序采样样本,始终按照同一个顺序
torch.utils.data.RandomSampler: 可指定有无放回地,进行随机采样样本元素
torch.utils.data.BatchSampler: 在一个batch中封装一个其他的采样器, 返回一个 batch 大小的 index 索引

3 DataLoader

torch.utils.data.DataLoader 是 PyTorch 数据加载的核心,负责加载数据,同时支持 Map-style 和 Iterable-style Dataset,支持单进程/多进程,还可以设置 loading order, batch size, pin memory 等加载参数。其接口定义如下:

DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
           batch_sampler=None, num_workers=0, collate_fn=None,
           pin_memory=False, drop_last=False, timeout=0,
           worker_init_fn=None, *, prefetch_factor=2,
           persistent_workers=False)

对于每个参数的含义,以下给出一个表格进行对应介绍:

attributemeaningdefault valuetype
dataset加载数据的数据集Dataset
batch_size每个 batch 加载多少个样本1int
shuffle设置为 True 时,调用 RandomSampler 进行随机索引Falsebool
sampler定义从数据集中提取样本的策略。如果指定了, shuffle 参数必须为 False,(否则会和 RandomSampler 互斥)NoneSampler, Iterable
batch_sampler和 sampler 类似,但是一般传入 BatchSampler,每次返回一个 batch 大小的索引。其和 batch_size, shuffle 等参数是互斥的NoneSampler, Iterable
num_workers要用于数据加载的子进程数,0 表示将在主进程中加载数据0int
collate_fn在将 Map-style dataset 取出的数据整合成 batch 时使用,合并样本列表以形成一个 batchNonecallable
pin_memory如果为 True,则 DataLoader 在将张量返回之前将其复制到 CUDA 固定的内存中Falsebool
drop_last设置为 True 删除最后一个不完整的批次,如果该数据集大小不能被该批次大小整除。如果 False 并且数据集的大小不能被批次大小整除,那么最后一批将较小Falsebool
timeout如果为正,则为从 worker 收集 batch 的超时值,应始终为非负数。超过这个时间还没读取到数据的话就会报错0numeric
worker_init_fn如果不为 None,它将会被每个 worker 子进程调用,以 worker id ([0, num_workers - 1] 内的整形) 为输入Nonecallable
prefetch_facto每个 worker 提前加载 的 sample 数量2int
persistent_workers如果为 True,dataloader 将不会终止 worker 进程,直到 dataset 迭代完成Falsebool

3.1 三者关系 (Dataset, Sampler, Dataloader)

  • 设置 Dataset,将数据 data source 包装成 Dataset 类,暴露提取接口。
  • 设置 Sampler,决定采样方式。我们是能从 Dataset 中提取元素了,还是需要设置 Sampler 告诉程序提取 Dataset 的策略。
  • 将设置好的 Dataset 和 Sampler 传入 DataLoader,同时可以设置 shuffle, batch_size 等参数。使用 DataLoader 对象可以方便快捷地在数据集上遍历。

总结来说,即 Dataloader 负责总的调度,命令 Sampler 定义遍历索引的方式,然后用索引去 Dataset 中提取元素。于是就实现了对给定数据集的遍历。

3.2 批处理

3.2.1 自动批处理(默认)

DataLoader 支持通过参数batch_size, drop_last, batch_sampler,自动地把取出的数据整理 (collate) 成批次样本 (batch)

batch_size 和 drop_last 参数用于指定 DataLoader 如何获取 dataset 的 key。特别地,对于 map-style 类型的 dataset,用户可以选择指定 batch_sample参数,一次就生成一个 keys list

在使用 sampler 产生的 indices 获取采样到的数据时,DataLoader 使用 collate_fn 参数将样本列表整理成 batch。抽象这个过程,其表示方式大致如下

# For Map-style
for indices in batch_sampler:
    yield collate_fn([dataset[i] for i in indices])

# For Iterable-style
dataset_iter = iter(dataset)
for indices in batch_sampler:
    yield collate_fn([next(dataset_iter) for _ in indices])

3.2.3 collate_fn

当关闭自动批处理 (automatic batching) 时,collate_fn 作用于单个数据样本,只是在 PyTorch 张量中转换 NumPy 数组。

当开启自动批处理 (automatic batching) 时,collate_fn 作用于数据样本列表,将输入样本整理为一个 batch,一般做下面 3 件事情

  • 添加新的批次维度(一般是第一维)
  • 它会自动将 NumPy 数组和 Python 数值转换为 PyTorch 张量
  • 它保留数据结构,例如,如果每个样本都是 dict,则输出具有相同键集但批处理过的张量作为值的字典(或list,当不能转换的时候)。list, tuples, namedtuples 同样适用

自定义 collate_fn 可用于自定义排序规则,例如,将顺序数据填充到批处理的最大长度,添加对自定义数据类型的支持等。

3.3 多进程处理 (multi-process)

为了避免在加载数据时阻塞计算代码,PyTorch 提供了一个简单的开关,只需将参数设置 num_workers 为正整数即可执行多进程数据加载,设置为 0 时执行单线程数据加载。

4 预取 (prefetch)

DataLoader 通过指定 prefetch_factor (默认为 2)来进行数据的预取。

class _MultiProcessingDataLoaderIter(_BaseDataLoaderIter):
    def __init__(self, loader):
        ...
        self._reset(loader, first_iter=True)

    def _reset(self, loader, first_iter=False):
        ...
        # prime the prefetch loop
        for _ in range(self._prefetch_factor * self._num_workers):
            self._try_put_index()

通过源码可以看到,prefetch 功能仅适用于 多进程 加载中(下面会由多进程 dataloader 的代码分析)

5 代码详解

来看看具体的代码调用流程:

for data, label in train_loader:
    ......

for 循环会调用 dataloader 的 iter(self) 方法,以此获得迭代器来遍历 dataset

class DataLoader(Generic[T_co]):
    ...
    def __iter__(self) -> '_BaseDataLoaderIter':

        if self.persistent_workers and self.num_workers > 0:
            if self._iterator is None:
                self._iterator = self._get_iterator()
            else:
                self._iterator._reset(self)
            return self._iterator
        else:
            return self._get_iterator()

iter(self) 方法中,dataloader 调用了self._get_iterator() 方法,根据 num_worker 获得迭代器,并指示进行单进程还是多进程

class DataLoader(Generic[T_co]):
    ...
    def _get_iterator(self) -> '_BaseDataLoaderIter':
        if self.num_workers == 0:
            return _SingleProcessDataLoaderIter(self)
        else:
            self.check_worker_number_rationality()
            return _MultiProcessingDataLoaderIter(self)

为了描述清晰,我们只考虑单进程的代码。下面是 class _SingleProcessDataLoaderIter(_BaseDataLoaderIter) ,以及其父类 class _BaseDataLoaderIter(object): 的重点代码片段:

class _BaseDataLoaderIter(object):
    def __init__(self, loader: DataLoader) -> None:
        # 初始化赋值一些 DataLoader 参数,
        # 以及用户输入合法性进行校验
        self._dataset = loader.dataset
        self._dataset_kind = loader._dataset_kind
        self._index_sampler = loader._index_sampler
        ...

    def __iter__(self) -> '_BaseDataLoaderIter':
        return self

    def _reset(self, loader, first_iter=False):
        self._sampler_iter = iter(self._index_sampler)
        self._num_yielded = 0
        self._IterableDataset_len_called = loader._IterableDataset_len_called

    def _next_index(self):
        return next(self._sampler_iter)  # may raise StopIteration

    def _next_data(self):
        raise NotImplementedError

    def __next__(self) -> Any:
        with torch.autograd.profiler.record_function(self._profile_name):
            if self._sampler_iter is None:
                self._reset()
            data = self._next_data() # 重点代码行,通过此获取数据
            self._num_yielded += 1
            ...
            return data

    next = __next__  # Python 2 compatibility

    def __len__(self) -> int:
        return len(self._index_sampler) # len(_BaseDataLoaderIter) == len(self._index_sampler)

    def __getstate__(self):
        raise NotImplementedError("{} cannot be pickled", self.__class__.__name__)

_BaseDataLoaderIter 是所有 DataLoaderIter 的父类。dataloader获得了迭代器之后,for 循环需要调用 __next__() 来获得下一个对象,从而实现遍历。通过 __next__ 方法调用 _next_data() 获取数据

class _SingleProcessDataLoaderIter(_BaseDataLoaderIter):
    def __init__(self, loader):
        super(_SingleProcessDataLoaderIter, self).__init__(loader)
        assert self._timeout == 0
        assert self._num_workers == 0

        self._dataset_fetcher = _DatasetKind.create_fetcher(
            self._dataset_kind, self._dataset, self._auto_collation, self._collate_fn, self._drop_last)

    def _next_data(self):
        index = self._next_index()  # may raise StopIteration
        data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
        if self._pin_memory:
            data = _utils.pin_memory.pin_memory(data)
        return data

从 _SingleProcessDataLoaderIter 的初始化参数可以看到,其在父类 _BaseDataLoaderIter 的基础上定义了 _dataset_fetcher, 并传入 _dataset, _auto_collation, _collate_fn 等参数,用于定义获取数据的方式。其具体实现会在稍后解释。

在 _next_data() 被调用后,其需要 next_index() 获取 index,并通过获得的 index 传入 _dataset_fetcher 中获取对应样本

class DataLoader(Generic[T_co]):
    ...
    @property
    def _auto_collation(self):
        return self.batch_sampler is not None

    @property
    def _index_sampler(self):
        if self._auto_collation:
            return self.batch_sampler
        else:
            return self.sampler

class _BaseDataLoaderIter(object):
    ...
    def _reset(self, loader, first_iter=False):
        self._sampler_iter = iter(self._index_sampler)
        ...

    def _next_index(self):
        # sampler_iter 来自于 index_sampler
        return next(self._sampler_iter)  # may raise StopIteration

从这里看出,dataloader 提供了 sampler (可以是batch_sampler 或者是其他 sampler 子类),然后 _SingleProcessDataLoaderIter 迭代sampler获得索引

下面我们来看看 fetcher,fetcher 需要 index 来获取元素,并同时支持 Map-style dataset(对应 _MapDatasetFetcher)和 Iterable-style dataset(对应 _IterableDatasetFetcher),使其在Dataloader内能使用相同的接口 fetch,代码更加简洁。

  • 对于 Map-style:直接输入索引 index,作为 map 的 key,获得对应的样本(即 value)
class _MapDatasetFetcher(_BaseDatasetFetcher):
    def __init__(self, dataset, auto_collation, collate_fn, drop_last):
        super(_MapDatasetFetcher, self).__init__(dataset, auto_collation, collate_fn, drop_last)

    def fetch(self, possibly_batched_index):
        if self.auto_collation:
            # 有batch_sampler,_auto_collation就为True,
            # 就优先使用batch_sampler,对应在fetcher中传入的就是一个batch的索引
            data = [self.dataset[idx] for idx in possibly_batched_index]
        else:
            data = self.dataset[possibly_batched_index]
        return self.collate_fn(data)
  • 对于 Iterable-style: init 方法内设置了 dataset 初始的迭代器,fetch 方法内获取元素,index 其实已经没有多大作用了
class _IterableDatasetFetcher(_BaseDatasetFetcher):
    def __init__(self, dataset, auto_collation, collate_fn, drop_last):
        super(_IterableDatasetFetcher, self).__init__(dataset, auto_collation, collate_fn, drop_last)
        self.dataset_iter = iter(dataset)

    def fetch(self, possibly_batched_index):
        if self.auto_collation:
            # 对于batch_sampler(即auto_collation==True)
            # 直接使用往后遍历并提取len(possibly_batched_index)个样本(即1个batch的样本)
            data = []
            for _ in possibly_batched_index:
                try:
                    data.append(next(self.dataset_iter))
                except StopIteration:
                    break
            if len(data) == 0 or (self.drop_last and len(data) < len(possibly_batched_index)):
                raise StopIteration
        else:
            # 对于sampler,直接往后遍历并提取1个样本
            data = next(self.dataset_iter)
        return self.collate_fn(data)

最后,我们通过索引传入 fetcher,fetch 得到想要的样本。
因此,整个过程调用关系总结 如下:

loader.__iter__ --> self._get_iterator() --> class _SingleProcessDataLoaderIter --> class _BaseDataLoaderIter --> __next__() --> self._next_data() --> self._next_index() -->next(self._sampler_iter)next(iter(self._index_sampler)) --> 获得 index --> self._dataset_fetcher.fetch(index) --> 获得 data

参考链接:PyTorch 源码解读之 torch.utils.data:解析数据处理全流程

问题记录:
1.pytorch 中的Dataset这个类为什么可以调用__getitem__?
特殊方法名称
在这里插入图片描述

  • 2
    点赞
  • 10
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值