DataLoader 的使用方法:
- 调用 dataloader. __iter__ 获取迭代器
- 调用 迭代器的 __next__ 获取下一个 batch
首先 dataloader 可以设置是否 shuffle
那么只要看 shuffle 参数对这个过程有什么影响即可
class DataLoader(Generic[T_co]):
def __init__(self, dataset: Dataset[T_co], batch_size: Optional[int] = 1,
shuffle: bool = False, sampler: Optional[Sampler[int]] = None,
batch_sampler: Optional[Sampler[Sequence[int]]] = None,
num_workers: int = 0, collate_fn: _collate_fn_t = None,
pin_memory: bool = False, drop_last: bool = False,
timeout: float = 0, worker_init_fn: _worker_init_fn_t = None,
multiprocessing_context=None, generator=None,
*, prefetch_factor: int = 2,
persistent_workers: bool = False):
注释:
shuffle (bool, optional): set to
True
to have the data reshuffled at every epoch (default:False
).
可以看到数据会在每个 epoch 中被 reshuffle。
其实现中,直接相关的代码有:
if shuffle:
sampler = RandomSampler(dataset, generator=generator) # type: ignore
else:
sampler = SequentialSampler(dataset)
# ......
self.sampler = sampler
若设置了 shuffle 为 True 则将采样实例化为随机采样。
当调用 dataloader 的 iter() :
def __iter__(self) -> '_BaseDataLoaderIter':
if self.persistent_workers and self.num_workers > 0:
if self._iterator is None:
self._iterator = self._get_iterator()
else:
self._iterator._reset(self)
return self._iterator
else:
return self._get_iterator()
会调用 self._get_iterator() 方法,并返回生成的 iterator,用户通过调用该 iterator 的 next() 方法获取每个 batch。
def _get_iterator(self) -> '_BaseDataLoaderIter':
if self.num_workers == 0:
return _SingleProcessDataLoaderIter(self)
else:
self.check_worker_number_rationality()
return _MultiProcessingDataLoaderIter(self)
该方法先判断是否多线程,然后选择生成对应的 DataLoaderIter,以单线程为例:
class _SingleProcessDataLoaderIter(_BaseDataLoaderIter):
def __init__(self, loader):
super(_SingleProcessDataLoaderIter, self).__init__(loader)
assert self._timeout == 0
assert self._num_workers == 0
self._dataset_fetcher = _DatasetKind.create_fetcher(
self._dataset_kind, self._dataset, self._auto_collation, self._collate_fn, self._drop_last)
def _next_data(self):
index = self._next_index() # may raise StopIteration
data = self._dataset_fetcher.fetch(index) # may raise StopIteration
if self._pin_memory:
data = _utils.pin_memory.pin_memory(data)
return data
没有看到该类中的 __next__实现,看其父类 _BaseDataLoaderIter ,无关代码全删了。
class _BaseDataLoaderIter(object):
def __init__(self, loader: DataLoader) -> None:
self._dataset = loader.dataset
self._index_sampler = loader._index_sampler
self._sampler_iter = iter(self._index_sampler)
self._num_yielded = 0
# 以及其他属性
def __next__(self) -> Any:
with torch.autograd.profiler.record_function(self._profile_name):
if self._sampler_iter is None:
self._reset()
data = self._next_data()
self._num_yielded += 1
return data
_BaseDataLoaderIter 调用了 dataloader 的 _index_sampler 方法,获取了一个 sampler,并调用其 __iter__得到一个迭代器,以迭代获取 batch 中元素的下标。
当用户调用 _BaseDataLoaderIter 其子类的 __next__ 时,调用由子类实现的 _next_data() 方法获取 batch,并然给当前的计数 +1。
当该迭代器用完的时候就会调用 self._reset() 重新生成一个新的迭代器,并将计数归零。
def _reset(self, loader, first_iter=False):
self._sampler_iter = iter(self._index_sampler)
self._num_yielded = 0
self._IterableDataset_len_called = loader._IterableDataset_len_called
_next_data() 在子类 _SingleProcessDataLoaderIter 中重写 :
def _next_data(self):
index = self._next_index() # may raise StopIteration
data = self._dataset_fetcher.fetch(index) # may raise StopIteration
if self._pin_memory:
data = _utils.pin_memory.pin_memory(data)
return data
逻辑为调用 self._next_index() 方法获取当前要返回的 batch 中元素,在 dataset 中的下标 index,然后使用 self._dataset_fetcher.fetch(index) 获取对应元素后返回即可。
_next_index 在父类中实现:
def _next_index(self):
return next(self._sampler_iter)
其实就是调用之前得到的迭代器的 __next__ 以获取下一组 index。
那么 sampler 是怎样获取 index 的呢?以下是之前出现过的 RandomSampler 的代码:
class RandomSampler(Sampler[int]):
def __init__(self, data_source: Sized, replacement: bool = False,
num_samples: Optional[int] = None, generator=None) -> None:
self.data_source = data_source
self.replacement = replacement
self._num_samples = num_samples
self.generator = generator
def __iter__(self):
n = len(self.data_source)
if self.generator is None:
generator = torch.Generator()
generator.manual_seed(int(torch.empty((), dtype=torch.int64).random_().item()))
else:
generator = self.generator
if self.replacement:
for _ in range(self.num_samples // 32):
yield from torch.randint(high=n, size=(32,), dtype=torch.int64, generator=generator).tolist()
yield from torch.randint(high=n, size=(self.num_samples % 32,), dtype=torch.int64, generator=generator).tolist()
else:
yield from torch.randperm(n, generator=self.generator).tolist()
其实只要看一行就够了
yield from torch.randperm(n, generator=self.generator).tolist()
torch.randperm 就是返回元素为从 0 到 n-1 的张量,将其做 tolist 之后再使用 yield from 将其内的元素一个个 yield 出来,所以返回的就是 0 到 n-1的随机数字,且不重复。
但是!DataLoaderIter 里头用的并不单单是这个 sampler,这是由于若使用该 sampler, 一次只能获取一条数据,但是咱们是一个 batch, 一个 batch 有多条数据。
回看_BaseDataLoaderIter 中, loader._index_sampler 调用的是以下方法
@property
def _index_sampler(self):
if self._auto_collation:
return self.batch_sampler
else:
return self.sampler
self._auto_collation = True 时 就会选择 self.batch_sampler
@property
def _auto_collation(self):
return self.batch_sampler is not None
self._auto_collation 什么时候等于 True 呢? 当 self.batch_sampler is not None。那 self.batch_sampler 啥时候会为 None 呢?
if batch_size is not None and batch_sampler is None:
batch_sampler = BatchSampler(sampler, batch_size, drop_last)
当 batch_size is None 的时候,好吧跟我们现在的情况没什么关系,_index_sampler 返回的就是 self.batch_sampler。
batch_sampler 在实例化时需要传入 sampler, batch_size, drop_last, 此处的 sampler 就是之前的 RandomSampler。batch_sampler 的 __iter__ 实现为:
def __iter__(self):
batch = []
for idx in self.sampler:
batch.append(idx)
if len(batch) == self.batch_size:
yield batch
batch = []
if len(batch) > 0 and not self.drop_last:
yield batch
首先从获取了 sampler 的迭代器,并且逐个获取 sampler 中的内容也就是数据的下标,当获取的数量与 batch size 相同时将会 yield 出去。
好了那么就非常清楚了:
- 调用 dataLoader.__iter__ 得到生成器 dataLoaderIter
- 用户调用 dataLoaderIter 的 __next__ 想要获取下一个 batch
- dataLoaderIter 调用 dataLoader 的 sampler.__next__ 获取一组下标用于从 dataset 获取 data
- dataLoader 的 sampler 是一个 batchSampler,它从 randomSampler 中获取所有打乱的 index
- batchSampler 根据 batch size 将 index 一组一组地交给 dataLoaderIter
- dataLoaderIter 调用 _dataset_fetcher.fetch(index) 获取该 batch 的 data 返回给用户
以上就是当 shuffle = True 时的获取逻辑了。当然,当 shuffle = False 时,使用的就不是 RandomSampler 而是 SequentialSampler 了,那么 index 就是顺序不变的,batch 内的元素也就不变了