Pytorch Dataloder之num_workers(上篇:单进程加载器)

dataloader需要加载数据,为了加速数据读取和处理,使用多进程是一个比较好的解决方法。num_workers则是控制数据加载时,子进程的数量,默认为0。

  1. num_workers=0时,表示采用单进程方法加载数据,在主进程中加载数据。

  2. num_workers=1时,表示采用多进程方法加载数据,但是只有一个子进程,使用该子进程加载数据。

  3. num_workers>1时,表示采用多进程方法加载数据,有num_workers个子进程。


代码参考:https://github.com/pytorch/pytorch/blob/159a2404bdc7fc54918c78d4bd290e5fa830dca7/torch/utils/data/dataloader.py

dataloader中关于单进程加载与多进程加载的逻辑代码如下:

1. 在从dataloader中取数据时,先调用dataloader的__iter__方法,__iter__方法中,则会优先调用self._get_iterator()方法,返回一个迭代器的实例化对象。

    def __iter__(self) -> '_BaseDataLoaderIter':
        # When using a single worker the returned iterator should be
        # created everytime to avoid reseting its state
        # However, in the case of a multiple workers iterator
        # the iterator is only created once in the lifetime of the
        # DataLoader object so that workers can be reused
        if self.persistent_workers and self.num_workers > 0:
            if self._iterator is None:
                self._iterator = self._get_iterator()
            else:
                self._iterator._reset(self)
            return self._iterator
        else:
            return self._get_iterator()

2. _get_iterator()方法通过判断self.num_workers参数,返回单进程迭代器_SingleProcessDataLoaderIter还是多进程迭代器_MultiProcessingDataLoaderIter。

    def _get_iterator(self) -> '_BaseDataLoaderIter':
        if self.num_workers == 0:
            return _SingleProcessDataLoaderIter(self)    # 单进程迭代器
        else:
            self.check_worker_number_rationality()
            return _MultiProcessingDataLoaderIter(self)  # 多进程迭代器

3. 两种迭代器都继承自基类_BaseDataLoaderIter。两种迭代器实现的功能都是通过__next__()方法,返回一个batch的数据。首先看基类_BaseDataLoaderIter:

3.1 在我们从dataloader中取数据时:for batch in dataloder中:会按照next(iter(dataloder))进行调用。iter(dataloder)返回的是_BaseDataLoaderIter类的实例化,我们称为iterator。next方法实际是作用在_BaseDataLoaderIter类,因此_BaseDataLoaderIter实现了__next__方法,最终返回一个batch的数据。

class _BaseDataLoaderIter(object):
    def __init__(self, loader: DataLoader) -> None:
        ***
        
    def _next_data(self):    # 3.由子类实现,由此方法返回最终数据,单进程/多进程实现方法不一样,重点关注。
        raise NotImplementedError

    def __next__(self) -> Any:   # 1.返回最终的一个batch的数据。
        with torch.autograd.profiler.record_function(self._profile_name):
            if self._sampler_iter is None:
                self._reset()
            data = self._next_data()    # 2.调用子类的_next_data()方法获取数据。
            self._num_yielded += 1
            if self._dataset_kind == _DatasetKind.Iterable and \
                    self._IterableDataset_len_called is not None and \
                    self._num_yielded > self._IterableDataset_len_called:
                warn_msg = ("Length of IterableDataset {} was reported to be {} (when accessing len(dataloader)), but {} "
                            "samples have been fetched. ").format(self._dataset, self._IterableDataset_len_called,
                                                                  self._num_yielded)
                if self._num_workers > 0:
                    warn_msg += ("For multiprocessing data-loading, this could be caused by not properly configuring the "
                                 "IterableDataset replica at each worker. Please see "
                                 "https://pytorch.org/docs/stable/data.html#torch.utils.data.IterableDataset for examples.")
                warnings.warn(warn_msg)
            return data
    
    ***

3.2 先看_SingleProcessDataLoaderIter,单进程迭代器相比多进程迭代器整体逻辑会简单些,本文先讲单进程迭代器,后续会出一篇单独讲多进程迭代器。主要看_next_data()方法,用于返回一个batch的数据。

class DataLoader(Generic[T_co]):

    @property
    def _auto_collation(self):    # 6.有batch_sampler则返回True
        return self.batch_sampler is not None
        
    @property
    def _index_sampler(self):
        if self._auto_collation:  # 5.有batch_sampler则返回batch_sampler
            return self.batch_sampler
        else:
            return self.sampler


class _BaseDataLoaderIter(object):
    def __init__(self, loader: DataLoader) -> None:
        ***
        self._index_sampler = loader._index_sampler # 4.通过调用dataloader的_index_sampler方法,获取batch_sampler
        self._sampler_iter = iter(self._index_sampler) # 3.self._sampler_iter是一个batch_sampler迭代器

    def _next_index(self):    # 2.self._sampler_iter是一个batch_sampler迭代器,返回一个batch数据的索引。
        return next(self._sampler_iter)


class _SingleProcessDataLoaderIter(_BaseDataLoaderIter):
    def __init__(self, loader):
        super(_SingleProcessDataLoaderIter, self).__init__(loader)
        assert self._timeout == 0
        assert self._num_workers == 0

        self._dataset_fetcher = _DatasetKind.create_fetcher(
            self._dataset_kind, self._dataset, self._auto_collation, self._collate_fn, self._drop_last)

    def _next_data(self):
        index = self._next_index()  # 1.调用基类的_next_index()方法,获取一个batch数据的索引。
        data = self._dataset_fetcher.fetch(index)  # 2.利用DatasetFetcher,调用fetch()方法,将索引转换成一个batch的数据。
        if self._pin_memory:
            data = _utils.pin_memory.pin_memory(data)
        return data

3.3 这里扩展讲解一下DatasetFetcher,代码不多,看注释,主要是_MapDatasetFetcher的fetch()方法。

class DataLoader(Generic[T_co]):

    @property
    def _auto_collation(self):    # 6.有batch_sampler则返回True
        return self.batch_sampler is not None
        
    @property
    def _index_sampler(self):
        if self._auto_collation:  # 5.有batch_sampler则返回batch_sampler
            return self.batch_sampler
        else:
            return self.sampler


class _BaseDataLoaderIter(object):
    def __init__(self, loader: DataLoader) -> None:
        ***
        self._index_sampler = loader._index_sampler # 4.通过调用dataloader的_index_sampler方法,获取batch_sampler
        self._sampler_iter = iter(self._index_sampler) # 3.self._sampler_iter是一个batch_sampler迭代器

    def _next_index(self):    # 2.self._sampler_iter是一个batch_sampler迭代器,返回一个batch数据的索引。
        return next(self._sampler_iter)


class _SingleProcessDataLoaderIter(_BaseDataLoaderIter):
    def __init__(self, loader):
        super(_SingleProcessDataLoaderIter, self).__init__(loader)
        assert self._timeout == 0
        assert self._num_workers == 0

        self._dataset_fetcher = _DatasetKind.create_fetcher(
            self._dataset_kind, self._dataset, self._auto_collation, self._collate_fn, self._drop_last)

    def _next_data(self):
        index = self._next_index()  # 1.调用基类的_next_index()方法,获取一个batch数据的索引。
        data = self._dataset_fetcher.fetch(index)  # 2.利用DatasetFetcher,调用fetch()方法,将索引转换成一个batch的数据。
        if self._pin_memory:
            data = _utils.pin_memory.pin_memory(data)
        return data

嗨,欢迎大家关注我的公众号《CV之路》,一起讨论问题,一起学习进步~。也欢迎大家关注我的GitHub仓库,我出的所有博文教程都是无偿分享的,只求个关注与Star~,多谢大家支持!

GitHub - gy-7/CV-Road (后续教程相关所有代码都会维护到此仓库)

 

  • 7
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

gy-7

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值