PyTorch DataLoader源码分析(三)

经过前面的铺垫,DataLoader的整体架构和依赖部件都已分析完毕:
PyTorch Dataloader源码分析(一)
PyTorch DataLoader源码分析(二)

三、DataLoader迭代器详解

这一章主要介绍DataLoader的核心部分——_SingleProcessDataLoaderIter和_MultiProcessDataLoaderIter。两者的区别顾名思义,一个用于单进程,一个用于多进程。

从代码实现上看,当用户选择的num_workers等于0时,
DataLoader返回_SingleProcessDataLoaderIter迭代器,否则返回_MultiProcessDataLoaderIter迭代器。

class DataLoader(object):
    ... ...
    def __iter__(self):
        if self.num_workers == 0:
            return _SingleProcessDataLoaderIter(self)
        else:
            return _MultiProcessingDataLoaderIter(self)

前面介绍过DataLoaderIter的工作流程:
DataLoader工作流程
无论是_SingleProcessDataLoaderIter还是_MultiProcessDataLoaderIter,工作流程都如上图,只不过各个部件的执行单元和执行时序有差别(后面会解释)。

1、_BaseDataLoaderIter父类

class _BaseDataLoaderIter(object):
    def __init__(self, loader):
        self._dataset = loader.dataset
        self._dataset_kind = loader._dataset_kind
        self._IterableDataset_len_called = loader._IterableDataset_len_called
        self._auto_collation = loader._auto_collation
        self._drop_last = loader.drop_last
        self._index_sampler = loader._index_sampler
        self._num_workers = loader.num_workers
        self._pin_memory = loader.pin_memory and torch.cuda.is_available()
        self._timeout = loader.timeout
        self._collate_fn = loader.collate_fn
        self._sampler_iter = iter(self._index_sampler)
        self._base_seed = torch.empty((), dtype=torch.int64).random_(generator=loader.generator).item()
        self._num_yielded = 0

    def __iter__(self):
        return self

    def _next_index(self):
        return next(self._sampler_iter)  # may raise StopIteration

    def _next_data(self):
        raise NotImplementedError

    def __next__(self):
        data = self._next_data()
        self._num_yielded += 1
        if self._dataset_kind == _DatasetKind.Iterable and \
                self._IterableDataset_len_called is not None and \
                self._num_yielded > self._IterableDataset_len_called:
            warn_msg = ("Length of IterableDataset {} was reported to be {} (when accessing len(dataloader)), but {} "
                        "samples have been fetched. ").format(self._dataset, self._IterableDataset_len_called,
                                                              self._num_yielded)
            if self._num_workers > 0:
                warn_msg += ("For multiprocessing data-loading, this could be caused by not properly configuring the "
                             "IterableDataset replica at each worker. Please see "
                             "https://pytorch.org/docs/stable/data.html#torch.utils.data.IterableDataset for examples.")
            warnings.warn(warn_msg)
        return data

    next = __next__  # Python 2 compatibility

    def __len__(self):
        return len(self._index_sampler)

    def __getstate__(self):
        # TODO: add limited pickling support for sharing an iterator
        # across multiple threads for HOGWILD.
        # Probably the best way to do this is by moving the sample pushing
        # to a separate thread and then just sharing the data queue
        # but signalling the end is tricky without a non-blocking API
        raise NotImplementedError("{} cannot be pickled", self.__class__.__name__)

_BaseDataLoaderIter中最重要的就是__next__方法,根据迭代器协议,遍历DataLoader的for循环每次都会调用其返回迭代器的__next__方法。在_BaseDataLoaderIter的__next__方法中,会固定调用__next_data方法获得数据,这么做应该是为了复用代码。因此,在_SingleProcessDataLoaderIter和_MultiProcessDataLoaderIter中,关注的重点便是其各自的__next_data方法。

2、_SingleProcessDataLoaderIter迭代器

_SingleProcessDataLoaderIter的实现非常简洁。对应到流程图上,‘self._next_index()’负责从sampler中拿到index,‘self._dataset_fetcher.fetch(index)’负责用index获得tensor,而’_utils.pin_memory.pin_memory(data)‘负责将pageble tensor转换成pinned tensor。这几个步骤从时序上来看是串行的,都由主进程执行,总耗时为所有部件耗时的总和。

class _SingleProcessDataLoaderIter(_BaseDataLoaderIter):
    def __init__(self, loader):
        super(_SingleProcessDataLoaderIter, self).__init__(loader)
        assert self._timeout == 0
        assert self._num_workers == 0

        self._dataset_fetcher = _DatasetKind.create_fetcher(
            self._dataset_kind, self._dataset, self._auto_collation, self._collate_fn, self._drop_last)

    def _next_data(self):
        index = self._next_index()  # may raise StopIteration
        data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
        if self._pin_memory:
            data = _utils.pin_memory.pin_memory(data)
        return data

3、_MultiProcessDataLoaderIter迭代器

_MultiProcessDataLoaderIter的工作流程和上图一样,没有变化,区别在于各部件的工作时序:Fetcher和Pin_memory这两步由单独的进程和线程执行,和主进程可以并行,目的便是使得DataLoader的耗时和网络的计算可以overlap,从而加快训练过程。之所以选择Fetcher和Pin_memory这两个步骤做并行,是因为DataLoader中主要的耗时操作(CPU bound和IO bound)都在这两个步骤中。

虽然工作流程没有变化,由于加入了多进程/多线程,时序理解起来还是略显复杂。在具体分析代码前,先通过下图大致展示其内部workflow以及重要数据结构:
在这里插入图片描述
构成_MultiProcessDataLoaderIter主体部分的主要是多个进程/线程和多个queue,进程/线程分别为:
主进程(主线程) main_thread。每次从data_queue中取一个数据,然后通过sampler获得一个index,发给对应index_queue。

  • 主进程(pin_memory线程) pin_memory_thread。每次从worker_result_queue中取一个数据,将其从pageble tensor转换成pinned tensor,然后送到data_queue中。
  • 子进程(num_worker个子进程) worker_1~n_process。每个进程负责:每次从index_queue中取一个下标数据,先将其从磁盘load到内存中,然后做一系列用户定义的前处理操作,完成后将其送到worker_result_queue中。

多个queue充当这多个进程/线程之间生产-消费关系的缓冲:

  • index_queue。存放数据为(send_idx, index),由main_thread生产,worker_1~n_process消费。其中send_idx是main_thread维护的记录任务顺序和数量的计数器,每发送一个index到index_queue中,send_idx便会加一,具体用途后续解释。
  • worker_result_queue。存放数据为(send_idx, pageble tensor),由worker_1~n_process产生,pin_memory_thread消费。
  • data_queue。存放数据为(send_idx, pinned tensor),由pin_memory_thread产生,main_thread消费。

这多个进程/线程各司其职,相互之间唯一的联系便是多个queue队列,当某个队列为空时,该队列的消费线程/进程便会被阻塞,符合典型的生产-消费模型。下面通过源码详细分析一下内部细节。

先看下_MultiProcessDataLoaderIter代码的主体结构,有个全局认识:

class _MultiProcessingDataLoaderIter(_BaseDataLoaderIter):
    def __init__(self, loader):
        # 调用时机:用户初始化DataLoader对象时,若num_worker > 0,便会构造_MultiProcessDataLoaderIter对象,进入该__init__方法。
        # 职责:从DataLoader对象中获得用户参数,初始化numworker个子进程、pin_memory线程以及多个队列queue,
        #      并下发2*num_worker数量的任务(即index)。

    def _try_get_data(self, timeout=_utils.MP_STATUS_CHECK_INTERVAL):
        # 调用时机:由_get_data方法调用。
        # 职责:从data_queue中取数据,并对各种异常进行处理。

    def _get_data(self):
        # 调用时机:由_next_data方法调用。
        # 职责:调用_try_get_data方法获取数据,并检查数据是否获取成功。
        
    def _next_data(self):
        # 调用时机:用户每次对DataLoader对象进行for循环迭代时,都会进入该方法。
        # 职责:作为迭代器的入口,该方法负责返回用户需要的数据,每次的工作流程如下:
        #         1、检查本次需要获取的数据是否已在缓存中(不在queue中),若在则直接从缓存取。
        #         2、若不在缓存中,则调用_get_data获取数据。
        #         3、若该数据不是本次应该等待的数据(即该数据的idx不等于ecvd_idx),则存到缓存中,返回第一步,否则进入下一步。
        #         4、获取数据后,调用_process_data做近一步处理并返回数据。

    def _try_put_index(self):
        # 调用时机:由_process_data方法调用。
        # 职责:1、从sampler对象中获得index(调用父类的_next_index方法)
        #      2、将(send_idx, index)送入对应的index_queue中
        #      3、send_idx加一

    def _process_data(self, data):
        # 调用时机:由_next_data方法调用。
        # 职责:先对rcvd_idx加一,再调用_try_put_index方法,然后返回之前从_get_data中获取的数据。

接下来针对这个几方法逐个进行解析(只抓主要流程,与shutdown处理相关的逻辑暂时略过)。

(1) __init__方法
def __init__(self, loader):
        super(_MultiProcessingDataLoaderIter, self).__init__(loader)
        ... ...
        # 1、创建多进程/线程间用于维护数据顺序的数据结构
        self._send_idx = 0  # idx of the next task to be sent to workers
        self._rcvd_idx = 0  # idx of the next task to be returned in __next__
        self._task_info = {}

        # 2、根据用户参数将num_worker个子进程和pin_memory线程创建并初始化
        self._index_queues = []
        self._workers = []
        for i in range(self._num_workers):
            index_queue = multiprocessing_context.Queue()
            # index_queue.cancel_join_thread()
            w = multiprocessing_context.Process(... ...)
            w.daemon = True
            w.start()
            self._index_queues.append(index_queue)
            self._workers.append(w)

        if self._pin_memory:
            self._data_queue = queue.Queue()
            pin_memory_thread = threading.Thread(... ...)
            pin_memory_thread.daemon = True
            pin_memory_thread.start()
        else:
            self._data_queue = self._worker_result_queue

        # 3、发送2*num_worker个index,让多进程/线程工作起来
        for _ in range(2 * self._num_workers):
            self._try_put_index()

在_MultiProcessDataLoaderIter对象主要的成员结构中,多个queue和进程/线程在前面已经介绍过各自用途,并梳理过它们之间的数据流关系。但是有三个重要的成员还没谈到,那就是send_idx、rcvd_idx和task_info。

在介绍这三个成员的用途前,我们先思考一个问题 :“_MultiProcessDataLoaderIter和_SingleProcessDataLoaderIter在功能上是等价的吗?”。

使用多进程/线程除了在性能上有较大区别外,在功能上也会产生意外的区别:在_SingleProcessDataLoaderIter中,所有操作都是串行的,先通过sampler对象拿到index,再用index去load对应数据。只要sampler产生的index序列一致,每次拿到的数据序列便一致。这个特性我们暂且称之为“顺序一致性”。换到多进程/线程场景中,“顺序一致性”就难以维持了。虽然主进程中main_thread拿到的index仍是串行的,可以保证发送index的”顺序一致性“,但使用index去load数据的操作是由多个子进程完成,严格来说,这num_worker个子进程除了load数据,还要做数据预处理,这两步很耗时,分别属于IO密集型和CPU密集型任务,就算每个子进程的负载(待处理数据量)一样,但耗时可能相差甚大(某个进程在占据CPU的过程中都可能被打断而切换,除非绑核),因此,num_worker个子进程的执行速度是无法保证的,这就导致worker_result_queue中的数据不一定是按照main_thread中产生的index的顺序。

为了解决在多进程/线程下导致的这种“顺序不一致”问题,便引入了send_idx、rcvd_idx和task_info成员。那具体如何解决呢?一个朴素的想法是“为每个index和tensor数据都附加一个id,用以标识该数据对应main_thread中产生index的顺序。每次从queue中拿数据时都检查其id的合法性,即顺序一致且递增,如果是该数据是乱序的,先缓存起来,再从queue中拿下一个,直到获取有合法id的数据为止”,_MultiProcessDataLoaderIter的做法便是如此。

其中,send_idx表示这是main_thread中产生的第几个index,rcvd_idx表示main_thread已经成功获取到的第几个index对应的tensor数据,而task_info便是用于缓存在queue中拿到的乱序的数据。具体的逻辑在后续的代码分析中。

(2)_next_data方法
def _next_data(self):
        while True:
            ... ...
            # 1、检查本次要拿的数据是否已经在缓存中
            if len(self._task_info[self._rcvd_idx]) == 2:
                data = self._task_info.pop(self._rcvd_idx)[1]
                return self._process_data(data)

            # 2、数据不在缓存中,调用_get_data从queue中拿数据
            idx, data = self._get_data()

            # 3、检查刚拿的数据是否顺序一致
            if idx != self._rcvd_idx:
                # 不一致则放到缓存中
                self._task_info[idx] += (data,)
            else:
                del self._task_info[idx]
                # 一致则交给_process_data处理
                return self._process_data(data)           

在_next_data中出现的这个判断“if len(self._task_info[self._rcvd_idx]) == 2”,表示的含义就是“_rcvd_idx对应的数据是否已经在缓存中”。之所以可以这么判断,是因为_task_info字典中的数据有两种情况:

  1. { _send_idx : (worker_queue_idx,) }
  2. { _send_idx : (worker_queue_idx, data, ) }

在__init__中可以看到,_task_info刚开始是个空的字典,情况1的赋值操作在_try_put_index方法中:

self._task_info[self._send_idx] = (worker_queue_idx,)

如果_next_data中拿到的对应_rcvd_idx的数据是顺序一致的,则删除_task_info中该项,如果顺序不一致,则将拿到的data添加到_task_info的对应项中:

# 不一致则放到缓存中
self._task_info[idx] += (data,)

因此_task_info[_rcvd_idx]如果有两个item,即“len(self._task_info[self._rcvd_idx]) == 2”,就表示该_rcvd_idx对应的数据已经在缓存_task_info中了。

(3)_get_data和_try_get_data
def _try_get_data(self, timeout=_utils.MP_STATUS_CHECK_INTERVAL):
        # Returns a 2-tuple:
        #   (bool: whether successfully get data, any: data if successful else None)
        try:
            data = self._data_queue.get(timeout=timeout)
            return (True, data)
        except Exception as e:
            ... ...
            if isinstance(e, queue.Empty):
                return (False, None)
            
def _get_data(self):
        if self._timeout > 0:
            success, data = self._try_get_data(self._timeout)
            if success:
                return data
            else:
                raise RuntimeError('DataLoader timed out after {} seconds'.format(self._timeout))
        elif self._pin_memory:
            while self._pin_memory_thread.is_alive():
                success, data = self._try_get_data()
                if success:
                    return data
            else:
                raise RuntimeError('Pin memory thread exited unexpectedly')
        else:
            while True:
                success, data = self._try_get_data()
                if success:
                    return data

_get_data中主要就是根据用户传入的参数(timeout和pin_memory)选择调用_try_get_data的参数。_try_get_data的主要工作就是从_data_queue中取数据然后返回出去,返回的数据有两种状态(True, data)和(False, None)。

(4)_try_put_index和_process_data
def _try_put_index(self):
        try:
        # 1、调用sampler获取index
            index = self._next_index()
        ... ...
        for _ in range(self._num_workers):  # find the next active worker, if any
            worker_queue_idx = next(self._worker_queue_idx_cycle)
            if self._workers_status[worker_queue_idx]:
                break

        # 2、将获得和index和send_idx打包送到对应的_index_queue中
        self._index_queues[worker_queue_idx].put((self._send_idx, index))
        # 3、更新用于保证数据顺序一致性的成员
        self._task_info[self._send_idx] = (worker_queue_idx,)
        self._send_idx += 1

def _process_data(self, data):
        self._rcvd_idx += 1
        self._try_put_index()
        ... ...
        return data

_process_data的逻辑比较简单,给_rcvd_idx加一,然后调用_try_put_index,而_try_put_index的核心职责已经在标注在上述代码注释中,其中第3步与前面通过判断len(self._task_info[self._rcvd_idx])是否等于2的操作相对应。

至此,_MultiProcessDataLoaderIter就介绍完毕了。遗憾的是,为了抓住主体结构,上述贴的代码中去除了很多其他判断逻辑,这些逻辑对于多进程/线程的运行鲁棒性具有重要的意义。

  • 6
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值