dataloader需要加载数据,为了加速数据读取和处理,使用多进程是一个比较好的解决方法。num_workers则是控制数据加载时,子进程的数量,默认为0。
-
num_workers=0时,表示采用单进程方法加载数据,在主进程中加载数据。
-
num_workers=1时,表示采用多进程方法加载数据,但是只有一个子进程,使用该子进程加载数据。
-
num_workers>1时,表示采用多进程方法加载数据,有num_workers个子进程。
dataloader中关于单进程加载与多进程加载的逻辑代码如下:
1. 在从dataloader中取数据时,先调用dataloader的__iter__方法,__iter__方法中,则会优先调用self._get_iterator()方法,返回一个迭代器的实例化对象。
def __iter__(self) -> '_BaseDataLoaderIter':
# When using a single worker the returned iterator should be
# created everytime to avoid reseting its state
# However, in the case of a multiple workers iterator
# the iterator is only created once in the lifetime of the
# DataLoader object so that workers can be reused
if self.persistent_workers and self.num_workers > 0:
if self._iterator is None:
self._iterator = self._get_iterator()
else:
self._iterator._reset(self)
return self._iterator
else:
return self._get_iterator()
2. _get_iterator()方法通过判断self.num_workers参数,返回单进程迭代器_SingleProcessDataLoaderIter还是多进程迭代器_MultiProcessingDataLoaderIter。
def _get_iterator(self) -> '_BaseDataLoaderIter':
if self.num_workers == 0:
return _SingleProcessDataLoaderIter(self) # 单进程迭代器
else:
self.check_worker_number_rationality()
return _MultiProcessingDataLoaderIter(self) # 多进程迭代器
3. 两种迭代器都继承自基类_BaseDataLoaderIter。两种迭代器实现的功能都是通过__next__()方法,返回一个batch的数据。首先看基类_BaseDataLoaderIter:
3.1 在我们从dataloader中取数据时:for batch in dataloder中:会按照next(iter(dataloder))进行调用。iter(dataloder)返回的是_BaseDataLoaderIter类的实例化,我们称为iterator。next方法实际是作用在_BaseDataLoaderIter类,因此_BaseDataLoaderIter实现了__next__方法,最终返回一个batch的数据。
class _BaseDataLoaderIter(object):
def __init__(self, loader: DataLoader) -> None:
***
def _next_data(self): # 3.由子类实现,由此方法返回最终数据,单进程/多进程实现方法不一样,重点关注。
raise NotImplementedError
def __next__(self) -> Any: # 1.返回最终的一个batch的数据。
with torch.autograd.profiler.record_function(self._profile_name):
if self._sampler_iter is None:
self._reset()
data = self._next_data() # 2.调用子类的_next_data()方法获取数据。
self._num_yielded += 1
if self._dataset_kind == _DatasetKind.Iterable and \
self._IterableDataset_len_called is not None and \
self._num_yielded > self._IterableDataset_len_called:
warn_msg = ("Length of IterableDataset {} was reported to be {} (when accessing len(dataloader)), but {} "
"samples have been fetched. ").format(self._dataset, self._IterableDataset_len_called,
self._num_yielded)
if self._num_workers > 0:
warn_msg += ("For multiprocessing data-loading, this could be caused by not properly configuring the "
"IterableDataset replica at each worker. Please see "
"https://pytorch.org/docs/stable/data.html#torch.utils.data.IterableDataset for examples.")
warnings.warn(warn_msg)
return data
***
3.2 先看_SingleProcessDataLoaderIter,单进程迭代器相比多进程迭代器整体逻辑会简单些,本文先讲单进程迭代器,后续会出一篇单独讲多进程迭代器。主要看_next_data()方法,用于返回一个batch的数据。
class DataLoader(Generic[T_co]):
@property
def _auto_collation(self): # 6.有batch_sampler则返回True
return self.batch_sampler is not None
@property
def _index_sampler(self):
if self._auto_collation: # 5.有batch_sampler则返回batch_sampler
return self.batch_sampler
else:
return self.sampler
class _BaseDataLoaderIter(object):
def __init__(self, loader: DataLoader) -> None:
***
self._index_sampler = loader._index_sampler # 4.通过调用dataloader的_index_sampler方法,获取batch_sampler
self._sampler_iter = iter(self._index_sampler) # 3.self._sampler_iter是一个batch_sampler迭代器
def _next_index(self): # 2.self._sampler_iter是一个batch_sampler迭代器,返回一个batch数据的索引。
return next(self._sampler_iter)
class _SingleProcessDataLoaderIter(_BaseDataLoaderIter):
def __init__(self, loader):
super(_SingleProcessDataLoaderIter, self).__init__(loader)
assert self._timeout == 0
assert self._num_workers == 0
self._dataset_fetcher = _DatasetKind.create_fetcher(
self._dataset_kind, self._dataset, self._auto_collation, self._collate_fn, self._drop_last)
def _next_data(self):
index = self._next_index() # 1.调用基类的_next_index()方法,获取一个batch数据的索引。
data = self._dataset_fetcher.fetch(index) # 2.利用DatasetFetcher,调用fetch()方法,将索引转换成一个batch的数据。
if self._pin_memory:
data = _utils.pin_memory.pin_memory(data)
return data
3.3 这里扩展讲解一下DatasetFetcher,代码不多,看注释,主要是_MapDatasetFetcher的fetch()方法。
class DataLoader(Generic[T_co]):
@property
def _auto_collation(self): # 6.有batch_sampler则返回True
return self.batch_sampler is not None
@property
def _index_sampler(self):
if self._auto_collation: # 5.有batch_sampler则返回batch_sampler
return self.batch_sampler
else:
return self.sampler
class _BaseDataLoaderIter(object):
def __init__(self, loader: DataLoader) -> None:
***
self._index_sampler = loader._index_sampler # 4.通过调用dataloader的_index_sampler方法,获取batch_sampler
self._sampler_iter = iter(self._index_sampler) # 3.self._sampler_iter是一个batch_sampler迭代器
def _next_index(self): # 2.self._sampler_iter是一个batch_sampler迭代器,返回一个batch数据的索引。
return next(self._sampler_iter)
class _SingleProcessDataLoaderIter(_BaseDataLoaderIter):
def __init__(self, loader):
super(_SingleProcessDataLoaderIter, self).__init__(loader)
assert self._timeout == 0
assert self._num_workers == 0
self._dataset_fetcher = _DatasetKind.create_fetcher(
self._dataset_kind, self._dataset, self._auto_collation, self._collate_fn, self._drop_last)
def _next_data(self):
index = self._next_index() # 1.调用基类的_next_index()方法,获取一个batch数据的索引。
data = self._dataset_fetcher.fetch(index) # 2.利用DatasetFetcher,调用fetch()方法,将索引转换成一个batch的数据。
if self._pin_memory:
data = _utils.pin_memory.pin_memory(data)
return data
嗨,欢迎大家关注我的公众号《CV之路》,一起讨论问题,一起学习进步~。也欢迎大家关注我的GitHub仓库,我出的所有博文教程都是无偿分享的,只求个关注与Star~,多谢大家支持!
GitHub - gy-7/CV-Road (后续教程相关所有代码都会维护到此仓库)