Pytorch Dataloder之num_workers(上篇:单进程加载器)

dataloader需要加载数据,为了加速数据读取和处理,使用多进程是一个比较好的解决方法。num_workers则是控制数据加载时,子进程的数量,默认为0。

  1. num_workers=0时,表示采用单进程方法加载数据,在主进程中加载数据。

  2. num_workers=1时,表示采用多进程方法加载数据,但是只有一个子进程,使用该子进程加载数据。

  3. num_workers>1时,表示采用多进程方法加载数据,有num_workers个子进程。


代码参考:https://github.com/pytorch/pytorch/blob/159a2404bdc7fc54918c78d4bd290e5fa830dca7/torch/utils/data/dataloader.py

dataloader中关于单进程加载与多进程加载的逻辑代码如下:

1. 在从dataloader中取数据时,先调用dataloader的__iter__方法,__iter__方法中,则会优先调用self._get_iterator()方法,返回一个迭代器的实例化对象。

    def __iter__(self) -> '_BaseDataLoaderIter':
        # When using a single worker the returned iterator should be
        # created everytime to avoid reseting its state
        # However, in the case of a multiple workers iterator
        # the iterator is only created once in the lifetime of the
        # DataLoader object so that workers can be reused
        if self.persistent_workers and self.num_workers > 0:
            if self._iterator is None:
                self._iterator = self._get_iterator()
            else:
                self._iterator._reset(self)
            return self._iterator
        else:
            return self._get_iterator()

2. _get_iterator()方法通过判断self.num_workers参数,返回单进程迭代器_SingleProcessDataLoaderIter还是多进程迭代器_MultiProcessingDataLoaderIter。

    def _get_iterator(self) -> '_BaseDataLoaderIter':
        if self.num_workers == 0:
            return _SingleProcessDataLoaderIter(self)    # 单进程迭代器
        else:
            self.check_worker_number_rationality()
            return _MultiProcessingDataLoaderIter(self)  # 多进程迭代器

3. 两种迭代器都继承自基类_BaseDataLoaderIter。两种迭代器实现的功能都是通过__next__()方法,返回一个batch的数据。首先看基类_BaseDataLoaderIter:

3.1 在我们从dataloader中取数据时:for batch in dataloder中:会按照next(iter(dataloder))进行调用。iter(dataloder)返回的是_BaseDataLoaderIter类的实例化,我们称为iterator。next方法实际是作用在_BaseDataLoaderIter类,因此_BaseDataLoaderIter实现了__next__方法,最终返回一个batch的数据。

class _BaseDataLoaderIter(object):
    def __init__(self, loader: DataLoader) -> None:
        ***
        
    def _next_data(self):    # 3.由子类实现,由此方法返回最终数据,单进程/多进程实现方法不一样,重点关注。
        raise NotImplementedError

    def __next__(self) -> Any:   # 1.返回最终的一个batch的数据。
        with torch.autograd.profiler.record_function(self._profile_name):
            if self._sampler_iter is None:
                self._reset()
            data = self._next_data()    # 2.调用子类的_next_data()方法获取数据。
            self._num_yielded += 1
            if self._dataset_kind == _DatasetKind.Iterable and \
                    self._IterableDataset_len_called is not None and \
                    self._num_yielded > self._IterableDataset_len_called:
                warn_msg = ("Length of IterableDataset {} was reported to be {} (when accessing len(dataloader)), but {} "
                            "samples have been fetched. ").format(self._dataset, self._IterableDataset_len_called,
                                                                  self._num_yielded)
                if self._num_workers > 0:
                    warn_msg += ("For multiprocessing data-loading, this could be caused by not properly configuring the "
                                 "IterableDataset replica at each worker. Please see "
                                 "https://pytorch.org/docs/stable/data.html#torch.utils.data.IterableDataset for examples.")
                warnings.warn(warn_msg)
            return data
    
    ***

3.2 先看_SingleProcessDataLoaderIter,单进程迭代器相比多进程迭代器整体逻辑会简单些,本文先讲单进程迭代器,后续会出一篇单独讲多进程迭代器。主要看_next_data()方法,用于返回一个batch的数据。

class DataLoader(Generic[T_co]):

    @property
    def _auto_collation(self):    # 6.有batch_sampler则返回True
        return self.batch_sampler is not None
        
    @property
    def _index_sampler(self):
        if self._auto_collation:  # 5.有batch_sampler则返回batch_sampler
            return self.batch_sampler
        else:
            return self.sampler


class _BaseDataLoaderIter(object):
    def __init__(self, loader: DataLoader) -> None:
        ***
        self._index_sampler = loader._index_sampler # 4.通过调用dataloader的_index_sampler方法,获取batch_sampler
        self._sampler_iter = iter(self._index_sampler) # 3.self._sampler_iter是一个batch_sampler迭代器

    def _next_index(self):    # 2.self._sampler_iter是一个batch_sampler迭代器,返回一个batch数据的索引。
        return next(self._sampler_iter)


class _SingleProcessDataLoaderIter(_BaseDataLoaderIter):
    def __init__(self, loader):
        super(_SingleProcessDataLoaderIter, self).__init__(loader)
        assert self._timeout == 0
        assert self._num_workers == 0

        self._dataset_fetcher = _DatasetKind.create_fetcher(
            self._dataset_kind, self._dataset, self._auto_collation, self._collate_fn, self._drop_last)

    def _next_data(self):
        index = self._next_index()  # 1.调用基类的_next_index()方法,获取一个batch数据的索引。
        data = self._dataset_fetcher.fetch(index)  # 2.利用DatasetFetcher,调用fetch()方法,将索引转换成一个batch的数据。
        if self._pin_memory:
            data = _utils.pin_memory.pin_memory(data)
        return data

3.3 这里扩展讲解一下DatasetFetcher,代码不多,看注释,主要是_MapDatasetFetcher的fetch()方法。

class DataLoader(Generic[T_co]):

    @property
    def _auto_collation(self):    # 6.有batch_sampler则返回True
        return self.batch_sampler is not None
        
    @property
    def _index_sampler(self):
        if self._auto_collation:  # 5.有batch_sampler则返回batch_sampler
            return self.batch_sampler
        else:
            return self.sampler


class _BaseDataLoaderIter(object):
    def __init__(self, loader: DataLoader) -> None:
        ***
        self._index_sampler = loader._index_sampler # 4.通过调用dataloader的_index_sampler方法,获取batch_sampler
        self._sampler_iter = iter(self._index_sampler) # 3.self._sampler_iter是一个batch_sampler迭代器

    def _next_index(self):    # 2.self._sampler_iter是一个batch_sampler迭代器,返回一个batch数据的索引。
        return next(self._sampler_iter)


class _SingleProcessDataLoaderIter(_BaseDataLoaderIter):
    def __init__(self, loader):
        super(_SingleProcessDataLoaderIter, self).__init__(loader)
        assert self._timeout == 0
        assert self._num_workers == 0

        self._dataset_fetcher = _DatasetKind.create_fetcher(
            self._dataset_kind, self._dataset, self._auto_collation, self._collate_fn, self._drop_last)

    def _next_data(self):
        index = self._next_index()  # 1.调用基类的_next_index()方法,获取一个batch数据的索引。
        data = self._dataset_fetcher.fetch(index)  # 2.利用DatasetFetcher,调用fetch()方法,将索引转换成一个batch的数据。
        if self._pin_memory:
            data = _utils.pin_memory.pin_memory(data)
        return data

嗨,欢迎大家关注我的公众号《CV之路》,一起讨论问题,一起学习进步~。也欢迎大家关注我的GitHub仓库,我出的所有博文教程都是无偿分享的,只求个关注与Star~,多谢大家支持!

GitHub - gy-7/CV-Road (后续教程相关所有代码都会维护到此仓库)

 

在使用PyTorch进行数据加载时,正确设置DataLoadernum_workers参数是确保多进程稳定运行的关键。当遇到'RuntimeError: DataLoader worker (pid(s) 9528, 8320) exited unexpectedly'这类错误时,可能是因为工作进程遇到了无法解决的问题或资源冲突。 参考资源链接:[RuntimeError: DataLoader worker (pid(s) 9528, 8320) exited unexpectedly](https://wenku.csdn.net/doc/64532556ea0840391e771115?spm=1055.2569.3001.10343) 首先,需要了解num_workers参数的作用,它定义了用于数据加载的工作进程数。设置num_workers为0意味着使用主进程进行加载,而不是多进程。如果你的目的是利用多进程来加速数据加载,合理的num_workers值应该是CPU核心数减去1或2,以留出足够的资源给主进程和其他可能的后台进程使用。 其次,错误可能是由于以下几个原因造成的: 1. dataset对象中自定义了__get_item__方法,在该方法中进行了诸如随机操作等不安全的操作,这可能会导致进程间的数据不一致或状态损坏。 2. 在自定义的dataset类中使用了Python的multiprocessing库中的锁或其他同步机制,与DataLoader的内部机制冲突。 为了解决这类问题,你可以尝试以下步骤: - 确保dataset的__get_item__方法是线程安全的,避免任何不安全的操作。 - 检查dataset中是否有使用到锁或其他同步机制,确保它们不会与DataLoader的工作进程产生冲突。 - 如果你的数据集是从磁盘读取,确保I/O操作是高效的,避免I/O瓶颈导致工作进程异常退出。 此外,可以通过修改DataLoader的默认工作进程初始化函数,使用spawn或forkserver方法来启动工作进程,这可能会解决一些在fork方法下遇到的问题。 通过上述步骤,你应该能够有效解决多进程数据加载时出现的意外退出问题。如果你希望进一步深入理解PyTorch DataLoader的工作机制和多进程数据加载的优化,建议查阅官方文档或参考《RuntimeError: DataLoader worker (pid(s) 9528, 8320) exited unexpectedly》这篇深入的教程,它将为你提供更多背景知识和实践指导。 参考资源链接:[RuntimeError: DataLoader worker (pid(s) 9528, 8320) exited unexpectedly](https://wenku.csdn.net/doc/64532556ea0840391e771115?spm=1055.2569.3001.10343)
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

gy-7

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值