paddle数据加载源码分析

刚整理完config配置文件,这里顺便把数据加载整理一下,方便以后查阅。

paddle数据加载源码分析

源码分析

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

class MyIterable():
    def __init__(self):
        self.data = [1, 2, 3, 4]
    def __iter__(self): # 返回一个迭代器
        return MyIterator(self.data)

    def __getitem__(self, idx): # 可以用[]对数组进行索引
       return self.data[idx]


class MyIterator():
    def __init__(self, data):
        self.data = data
        self.counter = 0

    def __iter__(self):
        return self 

    def __next__(self):
        if self.counter >= len(self.data):
            raise StopIteration()
        data = self.data[self.counter]
        self.counter += 1
        return data

my_iterable = MyIterable()

for d in my_iterable:
    print(d)

print(my_iterable[2])

输出:
在这里插入图片描述

在这里插入图片描述
在这里插入图片描述

Ddatset返回一张图
DataLoader返回一个batch的图

  • batch_size (int|None) - 每mini-batch中样本个数,为 batch_sampler 的替代参数,若 batch_sampler 未设置,会根据 batch_size shuffle drop_last 创建一个 paddle.io.BatchSampler 。默认值为1。

  • batch_sampler 设置从数据集中取数据的方式。默认值为None。

  • collate_fn (callable) - 通过此参数指定如果将样本列表组合为mini-batch数据,当 collate_fn 为None时,默认为将样本个字段在第0维上堆叠(同 np.stack(…, axis=0) )为mini-batch的数据。默认值为None。

在这里插入图片描述

class DataLoader(object):
    def __init__(self, dataset, batch_sampler=None, batch_size=1, collate_fn=None, ...):
        if batch_sampler is not None:
            self.batch_sampler = batch_sampler
        ...
        else:
			...
            if isinstance(dataset, IterableDataset):
                self.batch_sampler = _InfiniteIterableSampler(dataset,
                                                              batch_size)
            else:
                self.batch_sampler = BatchSampler(
                    dataset=dataset,
                    batch_size=batch_size,
                    shuffle=shuffle,
                    drop_last=drop_last)

    def __iter__(self):
        if self.num_workers == 0:
            return _DataLoaderIterSingleProcess(self) # 单进程取数据
        else:
            return _DataLoaderIterMultiProcess(self) # 多进程取数据

    def __call__(self):
        return self.__iter__()
    
    def __len__(self):
    	...

在这里插入图片描述

class _MapDatasetFetcher(_DatasetFetcher):
    def __init__(self, dataset, auto_collate_batch, collate_fn, drop_last):
        super(_MapDatasetFetcher, self).__init__(dataset, auto_collate_batch,
                                                 collate_fn, drop_last)

    def fetch(self, batch_indices, done_event=None):
        if self.auto_collate_batch:
            data = []
            for idx in batch_indices:
                if done_event is None or not done_event.is_set():
                    data.append(self.dataset[idx]) # 关键代码,将DataLoader和Dataset连接起来
                else:
                    return None

            global _WARNING_TO_LOG
            if not isinstance(data[0], (Sequence, Mapping)) \
                    and _WARNING_TO_LOG:
                self._log_warning()
                _WARNING_TO_LOG = False
        else:
            data = self.dataset[batch_indices]

        if self.collate_fn:
            data = self.collate_fn(data)
        return data

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

  • num_workers (int) - 用于加载数据的子进程个数,若为0即为不开启子进程,在主进程中进行数据加载。默认值为0。

参考资料

课件:自监督ViT算法:BeiT和MAE

DataLoader

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值