A Complete Walkthrough of DataLoader and Its Code Structure: dataset + sampler + fetcher

Overview

In deep learning, the DataLoader is a topic you cannot avoid: writing robust and efficient data-loading code requires knowing how to customize the DataLoader's dataset and sampler.

This article records my study of the DataLoader and my reading of its code.

Overall Architecture and Flow

First, to understand what the DataLoader is, let's look at its inputs, its outputs, and how it is used:

# create the dataloader object
data_loader = self.build_train_loader(cfg)

# usage pattern
class Trainer:
    def __init__(self):
        # first wrap the dataloader into an iterator
        self._data_loader_iter = iter(data_loader)

    def run_step(self):
        # sample a batch with next()
        data = next(self._data_loader_iter)
        self.optimizer.zero_grad()
        loss_dict = self.model(data)
        losses = sum(loss_dict.values())
        losses.backward()
        self.optimizer.step()
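The iter()/next() stepping pattern above can be sketched with a plain Python iterable; the toy list of "batches" below is a hypothetical stand-in for a real DataLoader:

```python
# Minimal sketch of the iter()/next() stepping pattern, using a plain
# list of "batches" in place of a real DataLoader (hypothetical data).
batches = [[1, 2], [3, 4], [5, 6]]

loader_iter = iter(batches)          # corresponds to iter(data_loader)

def run_step(it):
    """One training step: pull the next batch; StopIteration ends the epoch."""
    data = next(it)                  # corresponds to next(self._data_loader_iter)
    return sum(data)                 # stand-in for the forward/backward pass

print(run_step(loader_iter))  # first batch  -> 3
print(run_step(loader_iter))  # second batch -> 7
```

When the iterator is exhausted, `next()` raises `StopIteration`; training loops typically catch this to rebuild the iterator for the next epoch.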



@classmethod
def build_train_loader(cls, cfg):
    return build_reid_train_loader(cfg, combineall=cfg.DATASETS.COMBINEALL)


# define the dataloader-building function
@configurable(from_config=_train_loader_from_config)
def build_reid_train_loader(
        train_set, *, sampler=None, total_batch_size, num_workers=0,
):
    mini_batch_size = total_batch_size // comm.get_world_size()

    batch_sampler = torch.utils.data.sampler.BatchSampler(sampler, mini_batch_size, True)

    train_loader = DataLoaderX(
        comm.get_local_rank(),
        dataset=train_set,
        num_workers=num_workers,
        # sampler = sampler,
        batch_sampler=batch_sampler,
        collate_fn=fast_batch_collator,
        pin_memory=True,
    )

    return train_loader

As shown above, building the dataloader mainly requires a dataset, a sampler, and a batch sampler. The dataset must define a `__getitem__` function; the sampler generates the indices of the samples to draw, usually as a generator; and the batch sampler collects the indices produced by the sampler until it has batch_size of them, returning them as an index list.
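This division of labor can be sketched in plain Python with a hypothetical 10-sample dataset: the sampler yields indices, the batch sampler groups them, and the dataset maps each index to a sample:

```python
import random

# Hypothetical 10-sample "dataset": maps an index to a sample.
dataset = [f"sample_{i}" for i in range(10)]

def sampler(n, seed=0):
    """Sampler: yields indices one at a time (here, a shuffled permutation)."""
    order = list(range(n))
    random.Random(seed).shuffle(order)
    yield from order

def batch_sampler(index_iter, batch_size, drop_last=True):
    """BatchSampler: collects batch_size indices into one index list."""
    batch = []
    for idx in index_iter:
        batch.append(idx)
        if len(batch) == batch_size:
            yield batch
            batch = []
    if batch and not drop_last:
        yield batch

for indices in batch_sampler(sampler(len(dataset)), batch_size=4):
    print([dataset[i] for i in indices])
```

With 10 samples and batch_size=4, `drop_last=True` yields two full batches and silently drops the 2-element tail.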

The three components are implemented as follows:

# Build the dataset; __getitem__() lets samples be looked up by index.
class CommDataset(Dataset):
    """Image Person ReID Dataset"""

    def __init__(self, img_items, transform=None, relabel=True):
        pass

    def __len__(self):
        return len(self.img_items)

    def __getitem__(self, index):
        img_item = self.img_items[index]
        img_path = img_item[0]  # path of the cropped image
        pid = img_item[1]       # person (target) ID
        camid = img_item[2]     # camera ID
        img = read_image(img_path)
        if self.transform is not None:
            img = self.transform(img)
        if self.relabel:
            pid = self.pid_dict[pid]
            camid = self.cam_dict[camid]
        return {
            "images": img,
            "targets": pid,
            "camids": camid,
            "img_paths": img_path,
        }
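The relabel step maps raw person and camera IDs to contiguous labels starting at 0. A stdlib sketch of building those dictionaries is below; `build_relabel_dicts` is a hypothetical helper standing in for what CommDataset is assumed to do in its (elided) `__init__`:

```python
# Build relabel dictionaries from (path, pid, camid) items, mapping raw
# IDs to contiguous labels -- a sketch of CommDataset's assumed init logic.
def build_relabel_dicts(img_items):
    pids = sorted({item[1] for item in img_items})
    camids = sorted({item[2] for item in img_items})
    pid_dict = {pid: label for label, pid in enumerate(pids)}
    cam_dict = {camid: label for label, camid in enumerate(camids)}
    return pid_dict, cam_dict

items = [("a.jpg", 101, 3), ("b.jpg", 205, 1), ("c.jpg", 101, 1)]
pid_dict, cam_dict = build_relabel_dicts(items)
print(pid_dict)  # {101: 0, 205: 1}
print(cam_dict)  # {1: 0, 3: 1}
```

Contiguous labels are what a classification head expects; the raw IDs in the annotation files are usually sparse.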
# instantiating the sampler
data_sampler = samplers.InferenceSampler(test_set.img_items, mini_batch_size, 16)

# sampler class definition
class InferenceSampler(Sampler):
    def __init__(self, data_source: list, mini_batch_size: int, num_instances: int, seed: Optional[int] = None):
        pass

    def __iter__(self):
        start = self._rank
        yield from itertools.islice(self._infinite_indices(), start, None, self._world_size)

    def _infinite_indices(self):
        while True:
            # two ways to produce indices:
            # option 1: yield individual indices one at a time
            yield from reorder_index(batch_indices, self._world_size)
            # option 2: yield a whole list of batch indices at once
            yield reorder_index(batch_indices, self._world_size)


    def __len__(self):
        return len(self.imgids)
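The `itertools.islice(..., start, None, world_size)` call in `__iter__` shards one index stream across ranks in round-robin fashion, so each process sees a disjoint subset. A stdlib sketch with assumed rank/world-size values:

```python
import itertools

def shard(indices, rank, world_size):
    """Round-robin sharding: rank r takes elements r, r+world_size, ..."""
    return list(itertools.islice(indices, rank, None, world_size))

indices = list(range(8))
print(shard(indices, rank=0, world_size=2))  # [0, 2, 4, 6]
print(shard(indices, rank=1, world_size=2))  # [1, 3, 5, 7]
```

Together the shards cover every index exactly once, which is why distributed samplers only need a rank and a world size to partition the data.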
# constructing and using a BatchSampler
batch_sampler = torch.utils.data.BatchSampler(data_sampler, mini_batch_size, False)

# BatchSampler class definition
class BatchSampler(Sampler[List[int]]):
    def __init__(self, sampler: Union[Sampler[int], Iterable[int]], batch_size: int, drop_last: bool) -> None:
        self.sampler = sampler
        self.batch_size = batch_size
        self.drop_last = drop_last

    def __iter__(self) -> Iterator[List[int]]:
        # My own modification: when the sampler already yields index lists,
        # no batch assembly is needed here; forward each list directly.
        for idx in self.sampler:
            yield idx
        # The original implementation: collect indices one by one from the
        # sampler until a full batch of batch_size indices is assembled.
        # batch = []
        # for idx in self.sampler:
        #     batch.append(idx)
        #     if len(batch) == self.batch_size:
        #         yield batch
        #         batch = []
        # if len(batch) > 0 and not self.drop_last:
        #     yield batch

    def __len__(self) -> int:
        if self.drop_last:
            return len(self.sampler) // self.batch_size  # type: ignore[arg-type]
        else:
            return (len(self.sampler) + self.batch_size - 1) // self.batch_size  # type: ignore[arg-type]
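The original batch-assembly loop and the `__len__` formula (floor division with `drop_last`, ceiling division without) can be exercised in isolation; the helper names below are illustrative, not PyTorch's:

```python
def make_batches(indices, batch_size, drop_last):
    """The original BatchSampler logic: group indices into batch_size lists."""
    batch = []
    for idx in indices:
        batch.append(idx)
        if len(batch) == batch_size:
            yield batch
            batch = []
    if batch and not drop_last:
        yield batch

def num_batches(n, batch_size, drop_last):
    """Matches BatchSampler.__len__: floor vs ceiling division."""
    return n // batch_size if drop_last else (n + batch_size - 1) // batch_size

print(list(make_batches(range(5), 2, drop_last=True)))   # [[0, 1], [2, 3]]
print(list(make_batches(range(5), 2, drop_last=False)))  # [[0, 1], [2, 3], [4]]
```

Note that `(n + batch_size - 1) // batch_size` is the standard integer ceiling trick, matching the extra partial batch kept when `drop_last` is False.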
Now that the dataloader is built, let's see how the indices produced by its internal sampler are used to collect data from the dataset:
class DataLoader(Generic[T_co]):
    def __init__(self, dataset: Dataset[T_co], batch_size: Optional[int] = 1,
                 shuffle: bool = False, sampler: Union[Sampler, Iterable, None] = None,
                 batch_sampler: Union[Sampler[Sequence], Iterable[Sequence], None] = None,
                 num_workers: int = 0, collate_fn: Optional[_collate_fn_t] = None,...):
        pass

    # iter() and next() on a dataloader ultimately dispatch to this method
    def __iter__(self) -> '_BaseDataLoaderIter':
        return self._get_iterator()

    def _get_iterator(self) -> '_BaseDataLoaderIter':
        if self.num_workers == 0:
            return _SingleProcessDataLoaderIter(self)


    
class _SingleProcessDataLoaderIter(_BaseDataLoaderIter):
    def __init__(self, loader):
        super(_SingleProcessDataLoaderIter, self).__init__(loader)

        self._dataset_fetcher = _DatasetKind.create_fetcher(
            self._dataset_kind, self._dataset, self._auto_collation, self._collate_fn, self._drop_last)

    def _next_data(self):
        index = self._next_index()  # may raise StopIteration
        data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
        if self._pin_memory:
            data = _utils.pin_memory.pin_memory(data)
        return data

class _DatasetKind(object):
    @staticmethod
    def create_fetcher(kind, dataset, auto_collation, collate_fn, drop_last):
        return _utils.fetch._MapDatasetFetcher(dataset, auto_collation, collate_fn, drop_last)

# This is where the fetcher interacts with the dataset, collecting samples one by one
class _MapDatasetFetcher(_BaseDatasetFetcher):
    def __init__(self, dataset, auto_collation, collate_fn, drop_last):
        super(_MapDatasetFetcher, self).__init__(dataset, auto_collation, collate_fn, drop_last)

    def fetch(self, possibly_batched_index):
        if self.auto_collation:
            data = [self.dataset[idx] for idx in possibly_batched_index]
        else:
            data = self.dataset[possibly_batched_index]
        return self.collate_fn(data)
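`fetch()` is just indexing plus collation. A stdlib sketch follows, with a toy collate that groups dict fields into lists (the real `fast_batch_collator` stacks tensors instead; `dict_collate` here is a made-up stand-in):

```python
def fetch(dataset, possibly_batched_index, collate_fn, auto_collation=True):
    """Sketch of _MapDatasetFetcher.fetch: index the dataset, then collate."""
    if auto_collation:
        data = [dataset[idx] for idx in possibly_batched_index]
    else:
        data = dataset[possibly_batched_index]
    return collate_fn(data)

def dict_collate(samples):
    """Toy collate: turn a list of dicts into a dict of lists."""
    return {k: [s[k] for s in samples] for k in samples[0]}

# Hypothetical map-style dataset of dict samples.
dataset = [{"images": f"img{i}", "targets": i} for i in range(6)]
batch = fetch(dataset, [4, 0, 2], dict_collate)
print(batch)  # {'images': ['img4', 'img0', 'img2'], 'targets': [4, 0, 2]}
```

`auto_collation` is True whenever a batch sampler is used, which is why `possibly_batched_index` is then a list of indices rather than a single one.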

At this point the data-loading pipeline is complete: the dataloader is built, the sampler generates indices, and the fetcher retrieves the data.
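Putting it all together, the whole pipeline (sampler, batch sampler, fetcher, collate) fits in a few lines of plain Python; every name here is an illustrative stand-in for the PyTorch classes discussed above, not the real implementation:

```python
dataset = [{"x": i * 10, "y": i} for i in range(7)]  # map-style dataset

def sampler():
    """Sampler: a stream of indices."""
    yield from range(len(dataset))

def batch_sampler(batch_size):
    """BatchSampler: groups indices into batch_size lists (drop_last style)."""
    batch = []
    for idx in sampler():
        batch.append(idx)
        if len(batch) == batch_size:
            yield batch
            batch = []

def collate(samples):
    """collate_fn: list of dicts -> dict of lists."""
    return {k: [s[k] for s in samples] for k in samples[0]}

def data_loader(batch_size):
    """Fetcher loop: index the dataset with each index list, then collate."""
    for indices in batch_sampler(batch_size):
        yield collate([dataset[i] for i in indices])

for data in data_loader(3):
    print(data)
# {'x': [0, 10, 20], 'y': [0, 1, 2]}
# {'x': [30, 40, 50], 'y': [3, 4, 5]}
```

Seven samples with batch_size 3 yield two full batches; the 1-element tail is dropped, mirroring the `drop_last=True` behavior used in the training loader above.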

Please cite this article when reposting.
