Pytorch Dataloder之num_workers(上篇:多进程加载器)

PyTorch的_MultiProcessingDataLoaderIter是DataLoader在多进程模式下的迭代器。它的主要作用是通过多进程并行地预取数据批次,以提高数据加载的效率。

整体流程了解:_MultiProcessingDataLoaderIter是多生产者、一个消费者的模型。每个子进程都是生产者,用于读取数据,将读取好的数据放入到共享队列self._data_queue中。消费者会遍历self._data_queue进行数据读取,如果(消费者消费数据的索引==生产者生产数据的索引)则返回数据。否则,消费者会将从self._data_queue中取到的数据 存到字典self._task_info中,以便后续消费者读到该索引时,直接返回数据。然后消费者会继续读取共享队列self._data_queue,直至索引相等。

便于理解的关键变量:

  1. 生产者生产数据的索引 = self._send_idx = 子进程准备的batch的数据索引。

  2. 消费者消费数据的索引 = self._rcvd_idx = 主进程读取的batch的数据索引。

  3. 子进程维护的索引队列 = self._index_queues = 每个子进程根据这个队列从数据集中进行数据读取。

  4. self._index_queues详解:_MultiProcessingDataLoaderIter中,每个子进程会维护一个self._index_queues,里面存放的是(self._send_idx ,一个batch内所有数据的索引)。为了区分每个子进程读取的是第几个batch的数据,所以要加一个self._send_idx标记,为生产者生成数据的索引,用于后续与消费者进行匹配。Dataloader在初始化时,便把self._index_queues处理好了,所以每个子进程不会读取到重复的数据。

  5. self._index_queues样例:prefetch_factor=2,num_workers=3,batch_size=4。有三个子进程,每个子进程会取2个batch数据,共取6个batch的数据。

woker0.index_queues:[(0, [0,1,2,3]), (3, [12,13,14,15])]
woker1.index_queues:[(1, [4,5,6,7]), (4, [16,17,18,19])]
woker2.index_queues:[(2, [8,9,10,11]), (5, [20,21,22,23])]

_MultiProcessingDataLoaderIter的主要初始化代码如下,读取batch的索引变量名是_send_idx,详细看下面代码注释。

class _MultiProcessingDataLoaderIter(_BaseDataLoaderIter):
    def __init__(self, loader):
        super(_MultiProcessingDataLoaderIter, self).__init__(loader)

        self._index_queues = []
        self._workers = []
        for i in range(self._num_workers):    # 1.创建num_workers个子进程
            index_queue = multiprocessing_context.Queue() #2.每个子进程维护一个index_queue
            index_queue.cancel_join_thread()
            w = multiprocessing_context.Process(
                target=_utils.worker._worker_loop,
                args=(self._dataset_kind, self._dataset, index_queue,
                      self._worker_result_queue, self._workers_done_event,
                      self._auto_collation, self._collate_fn, self._drop_last,
                      self._base_seed, self._worker_init_fn, i, self._num_workers,
                      self._persistent_workers))
            w.daemon = True
            w.start()
            self._index_queues.append(index_queue)
            self._workers.append(w)

        self._reset(loader, first_iter=True)    # 3.初始化主/子进程的一些配置,初始化子进程index_queue。
        
    def _reset(self, loader, first_iter=False):
        super()._reset(loader, first_iter)
        self._send_idx = 0  # 子进程中读取batch的索引,子进程读取的是第几个batch的数据。
        self._rcvd_idx = 0  # 主进程中返回batch的索引,主进程返回的是第几个batch的数据。
        self._task_info = {}
        self._tasks_outstanding = 0
        self._workers_status = [True for i in range(self._num_workers)]
        if not first_iter:
            for idx in range(self._num_workers):
                self._index_queues[idx].put(_utils.worker._ResumeIteration())
            resume_iteration_cnt = self._num_workers
            while resume_iteration_cnt > 0:
                return_idx, return_data = self._get_data()
                if isinstance(return_idx, _utils.worker._ResumeIteration):
                    assert return_data is None
                    resume_iteration_cnt -= 1

        for _ in range(self._prefetch_factor * self._num_workers):
            self._try_put_index()
        # 4.上述两行主要作用:初始化/维护每个子进程的index_queue,详细看下面样例。
        # self._prefetch_factor是子进程预先加载的样本数,默认为2。
        '''
        样例:prefetch_factor=2,num_workers=3,batch_size=4
        有三个子进程,每个子进程会取2个batch数据,共取6个batch的数据。
        woker0.index_queue:[(0, [0,1,2,3]), (3, [12,13,14,15])]
        woker1.index_queue:[(1, [4,5,6,7]), (4, [16,17,18,19])]
        woker2.index_queue:[(2, [8,9,10,11]), (5, [20,21,22,23])]
        '''
        
    def _try_put_index(self):

        assert self._tasks_outstanding < self._prefetch_factor * self._num_workers
        try:
            index = self._next_index()    # 一个batch内所有数据的索引
        except StopIteration:
            return
        for _ in range(self._num_workers):  # 遍历总进程数
            worker_queue_idx = next(self._worker_queue_idx_cycle) # 循环所有的子进程
            if self._workers_status[worker_queue_idx]: # 如果当前子进程可用,则使用当前子进程,一般都是可用的
                break
        else:
            return
        self._index_queues[worker_queue_idx].put((self._send_idx, index))

        # 将子进程中(读取batch的索引,batch内数据的索引)绑定,维护到子进程的index_queue中。
        self._task_info[self._send_idx] = (worker_queue_idx,)
        self._tasks_outstanding += 1
        self._send_idx += 1    # 读取batch的索引+1

文章推荐阅读:https://www.zhihu.com/question/360391842/answer/931500421

  • 3
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

gy-7

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值