PyTorch的_MultiProcessingDataLoaderIter是DataLoader在多进程模式下的迭代器。它的主要作用是通过多进程并行地预取数据批次,以提高数据加载的效率。
整体流程了解:_MultiProcessingDataLoaderIter是多生产者、一个消费者的模型。每个子进程都是生产者,用于读取数据,将读取好的数据放入到共享队列self._data_queue中。消费者会遍历self._data_queue进行数据读取,如果(消费者消费数据的索引==生产者生产数据的索引)则返回数据。否则,消费者会将从self._data_queue中取到的数据 存到字典self._task_info中,以便后续消费者读到该索引时,直接返回数据。然后消费者会继续读取共享队列self._data_queue,直至索引相等。
便于理解的关键变量:
-
生产者生产数据的索引 = self._send_idx = 子进程准备的batch的数据索引。
-
消费者消费数据的索引 = self._rcvd_idx = 主进程读取的batch的数据索引。
-
子进程维护的索引队列 = self._index_queues = 每个子进程根据这个队列从数据集中进行数据读取。
-
self._index_queues详解:_MultiProcessingDataLoaderIter中,每个子进程会维护一个self._index_queues,里面存放的是(self._send_idx ,一个batch内所有数据的索引)。为了区分每个子进程读取的是第几个batch的数据,所以要加一个self._send_idx标记,为生产者生成数据的索引,用于后续与消费者进行匹配。Dataloader在初始化时,便把self._index_queues处理好了,所以每个子进程不会读取到重复的数据。
-
self._index_queues样例:prefetch_factor=2,num_workers=3,batch_size=4。有三个子进程,每个子进程会取2个batch数据,共取6个batch的数据。
woker0.index_queues:[(0, [0,1,2,3]), (3, [12,13,14,15])]
woker1.index_queues:[(1, [4,5,6,7]), (4, [16,17,18,19])]
woker2.index_queues:[(2, [8,9,10,11]), (5, [20,21,22,23])]
_MultiProcessingDataLoaderIter的主要初始化代码如下,读取batch的索引变量名是_send_idx,详细看下面代码注释。
class _MultiProcessingDataLoaderIter(_BaseDataLoaderIter):
def __init__(self, loader):
super(_MultiProcessingDataLoaderIter, self).__init__(loader)
self._index_queues = []
self._workers = []
for i in range(self._num_workers): # 1.创建num_workers个子进程
index_queue = multiprocessing_context.Queue() #2.每个子进程维护一个index_queue
index_queue.cancel_join_thread()
w = multiprocessing_context.Process(
target=_utils.worker._worker_loop,
args=(self._dataset_kind, self._dataset, index_queue,
self._worker_result_queue, self._workers_done_event,
self._auto_collation, self._collate_fn, self._drop_last,
self._base_seed, self._worker_init_fn, i, self._num_workers,
self._persistent_workers))
w.daemon = True
w.start()
self._index_queues.append(index_queue)
self._workers.append(w)
self._reset(loader, first_iter=True) # 3.初始化主/子进程的一些配置,初始化子进程index_queue。
def _reset(self, loader, first_iter=False):
super()._reset(loader, first_iter)
self._send_idx = 0 # 子进程中读取batch的索引,子进程读取的是第几个batch的数据。
self._rcvd_idx = 0 # 主进程中返回batch的索引,主进程返回的是第几个batch的数据。
self._task_info = {}
self._tasks_outstanding = 0
self._workers_status = [True for i in range(self._num_workers)]
if not first_iter:
for idx in range(self._num_workers):
self._index_queues[idx].put(_utils.worker._ResumeIteration())
resume_iteration_cnt = self._num_workers
while resume_iteration_cnt > 0:
return_idx, return_data = self._get_data()
if isinstance(return_idx, _utils.worker._ResumeIteration):
assert return_data is None
resume_iteration_cnt -= 1
for _ in range(self._prefetch_factor * self._num_workers):
self._try_put_index()
# 4.上述两行主要作用:初始化/维护每个子进程的index_queue,详细看下面样例。
# self._prefetch_factor是子进程预先加载的样本数,默认为2。
'''
样例:prefetch_factor=2,num_workers=3,batch_size=4
有三个子进程,每个子进程会取2个batch数据,共取6个batch的数据。
woker0.index_queue:[(0, [0,1,2,3]), (3, [12,13,14,15])]
woker1.index_queue:[(1, [4,5,6,7]), (4, [16,17,18,19])]
woker2.index_queue:[(2, [8,9,10,11]), (5, [20,21,22,23])]
'''
def _try_put_index(self):
assert self._tasks_outstanding < self._prefetch_factor * self._num_workers
try:
index = self._next_index() # 一个batch内所有数据的索引
except StopIteration:
return
for _ in range(self._num_workers): # 遍历总进程数
worker_queue_idx = next(self._worker_queue_idx_cycle) # 循环所有的子进程
if self._workers_status[worker_queue_idx]: # 如果当前子进程可用,则使用当前子进程,一般都是可用的
break
else:
return
self._index_queues[worker_queue_idx].put((self._send_idx, index))
# 将子进程中(读取batch的索引,batch内数据的索引)绑定,维护到子进程的index_queue中。
self._task_info[self._send_idx] = (worker_queue_idx,)
self._tasks_outstanding += 1
self._send_idx += 1 # 读取batch的索引+1
文章推荐阅读:https://www.zhihu.com/question/360391842/answer/931500421