dataloader对dataset中getitem方法调用(每一个epoch中通过transform实现数据增广

由于dataset中实现了各种transform方法,在每个epoch中通过DataLoader方法来实现。
也就是在每一个epoch都会执行相应的transform方法,对于RandomHorizontalFlip这种增广的方法是有一定的概率执行的。即每个epoch的同一张图片可能会有不同的transform实现。

dataset的实现子类一定要实现的方法有__init__,len,__getitem__这三个方法,而__getitem__方法会根据索引去对每张图片进行操作(包括transform操作)。在dataloader中有引入dataset的实例作为参数。

以下方法来自dataloader中,_get_iterator返回迭代器,里边调用_SingleProcessDataLoaderIter方法,然后调用_MapDatasetFetcher方法。在_MapDatasetFetcher()类当中,在这个类里面实现了具体的数据读取,具体代码如下。代码中调用了dataset,_next_index()是通过sample得到相应的一个batch_size的索引的,通过输入一个索引idx返回一个data,通过一系列的data拼接成一个list。
也就是对于dataloader的实例train_iter,通过for X,y in train_iter,也是一次取batch_size个数据。

class _MapDatasetFetcher(_BaseDatasetFetcher):
    def __init__(self, dataset, auto_collation, collate_fn, drop_last):
        super(_MapDatasetFetcher, self).__init__(dataset, auto_collation, collate_fn, drop_last)

    def fetch(self, possibly_batched_index):
        if self.auto_collation:
            data = [self.dataset[idx] for idx in possibly_batched_index]
        else:
            data = self.dataset[possibly_batched_index]
        return self.collate_fn(data)

同时,dataset中也可实现collate_fn方法,对数据进行组织。

class _DatasetKind(object):
    Map = 0
    Iterable = 1

    @staticmethod
    def create_fetcher(kind, dataset, auto_collation, collate_fn, drop_last):
        if kind == _DatasetKind.Map:
            return _utils.fetch._MapDatasetFetcher(dataset, auto_collation, collate_fn, drop_last)
        else:
            return _utils.fetch._IterableDatasetFetcher(dataset, auto_collation, collate_fn, drop_last)
class _SingleProcessDataLoaderIter(_BaseDataLoaderIter):
    def __init__(self, loader):
        super(_SingleProcessDataLoaderIter, self).__init__(loader)
        assert self._timeout == 0
        assert self._num_workers == 0

        self._dataset_fetcher = _DatasetKind.create_fetcher(
            self._dataset_kind, self._dataset, self._auto_collation, self._collate_fn, self._drop_last)

    def _next_data(self):
        index = self._next_index()  # may raise StopIteration
        data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
        if self._pin_memory:
            data = _utils.pin_memory.pin_memory(data)
        return data

_next_index()是通过sample得到相应的一个batch_size的索引的,然后根据这些索引通过_dataset_fetcher.fetch(index)得到相应的数据

    def _get_iterator(self) -> '_BaseDataLoaderIter':
        if self.num_workers == 0:
            return _SingleProcessDataLoaderIter(self)
        else:
            self.check_worker_number_rationality()
            return _MultiProcessingDataLoaderIter(self)

参考链接https://blog.csdn.net/qq_37388085/article/details/102663166

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值