由于dataset中实现了各种transform方法,在每个epoch中通过DataLoader方法来实现。
也就是在每一个epoch都会执行相应的transform方法,对于RandomHorizontalFlip这种增广的方法是有一定的概率执行的。即每个epoch的同一张图片可能会有不同的transform实现。
dataset的实现子类一定要实现的方法有__init__,len,__getitem__这三个方法,而__getitem__方法会根据索引去对每张图片进行操作(包括transform操作)。在dataloader中有引入dataset的实例作为参数。
以下方法来自dataloader中,_get_iterator返回迭代器,里边调用_SingleProcessDataLoaderIter方法,然后调用_MapDatasetFetcher方法。在_MapDatasetFetcher()类当中,在这个类里面实现了具体的数据读取,具体代码如下。代码中调用了dataset,_next_index()是通过sample得到相应的一个batch_size的索引的,通过输入一个索引idx返回一个data,通过一系列的data拼接成一个list。
也就是对于dataloader的实例train_iter,通过for X,y in train_iter,也是一次取batch_size个数据。
class _MapDatasetFetcher(_BaseDatasetFetcher):
def __init__(self, dataset, auto_collation, collate_fn, drop_last):
super(_MapDatasetFetcher, self).__init__(dataset, auto_collation, collate_fn, drop_last)
def fetch(self, possibly_batched_index):
if self.auto_collation:
data = [self.dataset[idx] for idx in possibly_batched_index]
else:
data = self.dataset[possibly_batched_index]
return self.collate_fn(data)
同时,dataset中也可实现collate_fn方法,对数据进行组织。
class _DatasetKind(object):
Map = 0
Iterable = 1
@staticmethod
def create_fetcher(kind, dataset, auto_collation, collate_fn, drop_last):
if kind == _DatasetKind.Map:
return _utils.fetch._MapDatasetFetcher(dataset, auto_collation, collate_fn, drop_last)
else:
return _utils.fetch._IterableDatasetFetcher(dataset, auto_collation, collate_fn, drop_last)
class _SingleProcessDataLoaderIter(_BaseDataLoaderIter):
def __init__(self, loader):
super(_SingleProcessDataLoaderIter, self).__init__(loader)
assert self._timeout == 0
assert self._num_workers == 0
self._dataset_fetcher = _DatasetKind.create_fetcher(
self._dataset_kind, self._dataset, self._auto_collation, self._collate_fn, self._drop_last)
def _next_data(self):
index = self._next_index() # may raise StopIteration
data = self._dataset_fetcher.fetch(index) # may raise StopIteration
if self._pin_memory:
data = _utils.pin_memory.pin_memory(data)
return data
_next_index()是通过sample得到相应的一个batch_size的索引的,然后根据这些索引通过_dataset_fetcher.fetch(index)得到相应的数据
def _get_iterator(self) -> '_BaseDataLoaderIter':
if self.num_workers == 0:
return _SingleProcessDataLoaderIter(self)
else:
self.check_worker_number_rationality()
return _MultiProcessingDataLoaderIter(self)
参考链接https://blog.csdn.net/qq_37388085/article/details/102663166