Overview
In deep learning, the DataLoader is a topic you cannot avoid: writing robust, efficient data-loading code requires knowing how to customize the dataset and the sampler inside the DataLoader.
This article records my study of the DataLoader and my reading of its code.
Overall Architecture and Flow
First, to understand what a DataLoader is, let's look at its inputs, its outputs, and how it is used:
# Create the dataloader object
data_loader = self.build_train_loader(cfg)

# How it is used inside a trainer
class Trainer:
    def __init__(self, cfg):
        data_loader = self.build_train_loader(cfg)
        # First turn the dataloader into an iterator
        self._data_loader_iter = iter(data_loader)

    def run_step(self):
        # Draw one batch with next()
        data = next(self._data_loader_iter)
        self.optimizer.zero_grad()
        loss_dict = self.model(data)
        losses = sum(loss_dict.values())
        losses.backward()
        self.optimizer.step()

    @classmethod
    def build_train_loader(cls, cfg):
        return build_reid_train_loader(cfg, combineall=cfg.DATASETS.COMBINEALL)
# Definition of the dataloader-building function
@configurable(from_config=_train_loader_from_config)
def build_reid_train_loader(
        train_set, *, sampler=None, total_batch_size, num_workers=0,
):
    # Split the global batch size across distributed processes
    mini_batch_size = total_batch_size // comm.get_world_size()
    batch_sampler = torch.utils.data.sampler.BatchSampler(sampler, mini_batch_size, True)
    train_loader = DataLoaderX(
        comm.get_local_rank(),
        dataset=train_set,
        num_workers=num_workers,
        # sampler=sampler,
        batch_sampler=batch_sampler,  # mutually exclusive with sampler
        collate_fn=fast_batch_collator,
        pin_memory=True,
    )
    return train_loader
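The collate_fn passed here, fast_batch_collator, is defined in fastreid; its job is to merge the list of per-sample dicts returned by __getitem__ into a single batched dict. Below is a rough, assumed sketch of what such a collator does (illustration only, not the library's actual code):

import torch

def simple_batch_collator(batched_inputs):
    # batched_inputs is a list of dicts, one per sample, as returned by
    # the dataset's __getitem__. Stack tensor fields, keep the rest as lists.
    elem = batched_inputs[0]
    out = {}
    for key in elem:
        values = [d[key] for d in batched_inputs]
        if isinstance(elem[key], torch.Tensor):
            out[key] = torch.stack(values, dim=0)
        elif isinstance(elem[key], int):
            out[key] = torch.tensor(values)
        else:
            out[key] = values  # e.g. img_paths stay a list of strings
    return out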
As the code above shows, building the dataloader mainly takes three ingredients: a dataset, a sampler, and a batch sampler. The dataset must implement the __getitem__ method so that samples can be fetched by index; the sampler generates the indices to sample, and is usually implemented as a generator; the batch sampler collects the indices produced by the sampler and returns them as lists of batch_size indices.
The three are implemented as follows:
# Build the dataset; __getitem__() lets samples be fetched by index.
class CommDataset(Dataset):
    """Image Person ReID Dataset"""

    def __init__(self, img_items, transform=None, relabel=True):
        pass

    def __len__(self):
        return len(self.img_items)

    def __getitem__(self, index):
        img_item = self.img_items[index]
        img_path = img_item[0]  # path to the (cropped) image
        pid = img_item[1]       # person / target ID
        camid = img_item[2]     # camera ID
        img = read_image(img_path)
        if self.transform is not None:
            img = self.transform(img)
        if self.relabel:
            pid = self.pid_dict[pid]
            camid = self.cam_dict[camid]
        return {
            "images": img,
            "targets": pid,
            "camids": camid,
            "img_paths": img_path,
        }
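To make the map-style contract concrete, here is a minimal, self-contained toy dataset (hypothetical names and data, not the fastreid class) that mirrors CommDataset's (path, pid, camid) items and can be indexed the same way:

from torch.utils.data import Dataset

class ToyReIDDataset(Dataset):
    """Minimal map-style dataset with CommDataset-like items."""

    def __init__(self, img_items):
        self.img_items = img_items
        # relabel: map raw pids/camids to contiguous ids starting at 0
        self.pid_dict = {p: i for i, p in enumerate(sorted({it[1] for it in img_items}))}
        self.cam_dict = {c: i for i, c in enumerate(sorted({it[2] for it in img_items}))}

    def __len__(self):
        return len(self.img_items)

    def __getitem__(self, index):
        img_path, pid, camid = self.img_items[index]
        return {"img_paths": img_path,
                "targets": self.pid_dict[pid],
                "camids": self.cam_dict[camid]}

items = [("a.jpg", 101, 3), ("b.jpg", 205, 1), ("c.jpg", 101, 1)]
ds = ToyReIDDataset(items)
print(len(ds))  # 3
print(ds[0])    # {'img_paths': 'a.jpg', 'targets': 0, 'camids': 1}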
# Instantiating the sampler
data_sampler = samplers.InferenceSampler(test_set.img_items, mini_batch_size, 16)

# Sampler class definition
class InferenceSampler(Sampler):
    def __init__(self, data_source: str, mini_batch_size: int, num_instances: int,
                 seed: Optional[int] = None):
        pass

    def __iter__(self):
        # Each rank starts at its own offset and strides by world_size,
        # so different processes draw disjoint indices from the same stream.
        start = self._rank
        yield from itertools.islice(self._infinite_indices(), start, None, self._world_size)

    def _infinite_indices(self):
        while condition:  # pseudocode: loop until the stream should stop
            # Two ways to emit indices:
            # Option 1: yield single indices one by one
            yield from reorder_index(batch_indices, self._world_size)
            # Option 2: yield the whole batch_indices list at once
            # yield reorder_index(batch_indices, self._world_size)

    def __len__(self):
        return len(self.imgids)
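The key trick in __iter__ is itertools.islice(stream, rank, None, world_size): every rank reads the same infinite index stream but keeps only every world_size-th element, starting at its own offset, so the ranks receive disjoint indices. A self-contained sketch of just that pattern (hypothetical toy functions, not fastreid's class):

import itertools

def infinite_indices(n):
    # Endless stream of indices 0..n-1, repeated epoch after epoch
    while True:
        yield from range(n)

def rank_stream(n, rank, world_size):
    # Each rank takes every world_size-th index, offset by its own rank
    return itertools.islice(infinite_indices(n), rank, None, world_size)

# With 6 samples and 2 ranks, the first 6 draws on each rank:
print(list(itertools.islice(rank_stream(6, 0, 2), 6)))  # [0, 2, 4, 0, 2, 4]
print(list(itertools.islice(rank_stream(6, 1, 2), 6)))  # [1, 3, 5, 1, 3, 5]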
# Using the batch sampler
batch_sampler = torch.utils.data.BatchSampler(data_sampler, mini_batch_size, False)

# BatchSampler definition
class BatchSampler(Sampler[List[int]]):
    def __init__(self, sampler: Union[Sampler[int], Iterable[int]],
                 batch_size: int, drop_last: bool) -> None:
        # (validation checks elided)
        self.sampler = sampler
        self.batch_size = batch_size
        self.drop_last = drop_last

    def __iter__(self) -> Iterator[List[int]]:
        # Custom modification: when the sampler already yields whole index lists,
        # no batch assembly is needed here; each list is passed straight through.
        for idx in self.sampler:
            yield idx

        # Original implementation: collect indices one by one until a batch is full
        # batch = []
        # for idx in self.sampler:
        #     batch.append(idx)
        #     if len(batch) == self.batch_size:
        #         yield batch
        #         batch = []
        # if len(batch) > 0 and not self.drop_last:
        #     yield batch

    def __len__(self) -> int:
        if self.drop_last:
            return len(self.sampler) // self.batch_size  # type: ignore[arg-type]
        else:
            return (len(self.sampler) + self.batch_size - 1) // self.batch_size  # type: ignore[arg-type]
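The stock grouping behavior is easy to verify directly with PyTorch's own samplers (toy data only):

from torch.utils.data import SequentialSampler, BatchSampler

sampler = SequentialSampler(range(10))  # yields 0, 1, ..., 9
batch_sampler = BatchSampler(sampler, batch_size=4, drop_last=False)
print(list(batch_sampler))  # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]

# With drop_last=True the incomplete final batch is discarded
print(list(BatchSampler(sampler, batch_size=4, drop_last=True)))
# [[0, 1, 2, 3], [4, 5, 6, 7]]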
With the dataloader built, let's look at how the indices generated by its sampler are used internally to pull data out of the dataset:
class DataLoader(Generic[T_co]):
    def __init__(self, dataset: Dataset[T_co], batch_size: Optional[int] = 1,
                 shuffle: bool = False,
                 sampler: Union[Sampler, Iterable, None] = None,
                 batch_sampler: Union[Sampler[Sequence], Iterable[Sequence], None] = None,
                 num_workers: int = 0,
                 collate_fn: Optional[_collate_fn_t] = None, ...):
        pass

    # Calling iter() / next() on a dataloader lands here
    def __iter__(self) -> '_BaseDataLoaderIter':
        return self._get_iterator()

    def _get_iterator(self) -> '_BaseDataLoaderIter':
        if self.num_workers == 0:
            return _SingleProcessDataLoaderIter(self)


class _SingleProcessDataLoaderIter(_BaseDataLoaderIter):
    def __init__(self, loader):
        super(_SingleProcessDataLoaderIter, self).__init__(loader)
        self._dataset_fetcher = _DatasetKind.create_fetcher(
            self._dataset_kind, self._dataset, self._auto_collation,
            self._collate_fn, self._drop_last)

    def _next_data(self):
        index = self._next_index()  # may raise StopIteration
        data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
        if self._pin_memory:
            data = _utils.pin_memory.pin_memory(data)
        return data


class _DatasetKind(object):
    @staticmethod
    def create_fetcher(kind, dataset, auto_collation, collate_fn, drop_last):
        return _utils.fetch._MapDatasetFetcher(dataset, auto_collation, collate_fn, drop_last)


# This is where the dataset is actually touched: samples are fetched one by one
class _MapDatasetFetcher(_BaseDatasetFetcher):
    def __init__(self, dataset, auto_collation, collate_fn, drop_last):
        super(_MapDatasetFetcher, self).__init__(dataset, auto_collation, collate_fn, drop_last)

    def fetch(self, possibly_batched_index):
        if self.auto_collation:
            data = [self.dataset[idx] for idx in possibly_batched_index]
        else:
            data = self.dataset[possibly_batched_index]
        return self.collate_fn(data)
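Putting the pieces together, the whole round trip (batch sampler yields an index list, the fetcher indexes the dataset with it, collate_fn merges the samples) can be reproduced with a toy map-style dataset; everything below is standard PyTorch API, only the dataset is made up:

import torch
from torch.utils.data import Dataset, DataLoader, SequentialSampler, BatchSampler

class SquaresDataset(Dataset):
    def __len__(self):
        return 10

    def __getitem__(self, index):
        return {"x": torch.tensor(index), "y": torch.tensor(index ** 2)}

ds = SquaresDataset()
batch_sampler = BatchSampler(SequentialSampler(ds), batch_size=4, drop_last=False)
loader = DataLoader(ds, batch_sampler=batch_sampler)  # default collate stacks dict fields

it = iter(loader)  # -> _SingleProcessDataLoaderIter, since num_workers=0
batch = next(it)   # fetcher runs [ds[i] for i in [0, 1, 2, 3]], then collates
print(batch["x"])  # tensor([0, 1, 2, 3])
print(batch["y"])  # tensor([0, 1, 4, 9])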
At this point the chain is complete: the dataloader is built, the sampler generates indices, and the fetcher retrieves the corresponding data. Data loading ends here.
Please cite this article when reposting.