pytorch数据载入Dataloader类

pytorch数据载入Dataloader类

  pytorch的数据导入一般均由Dataloader方法实现,但Dataloader含1000多行代码,阅读起来让人望而却步,pytorch官方文档对该部分所实现的功能有大致介绍,能够帮助你建立关于Dataloader功能的一个框架,文档为全英文,但值得阅读。
pytorch官方文档Dataloader

  涉及迭代器与生成器的相关知识

  下述代码为使用DataLoader载入数据的一般形式及DataLoader类的输入参数。

    train_data_loader = torch.utils.data.DataLoader(train_data_set,
                                                    batch_size=batch_size,
                                                    shuffle=True,
                                                    num_workers=nw,
                                                    collate_fn=train_data_set.collate_fn)

    def __init__(self, dataset: Dataset[T_co], batch_size: Optional[int] = 1,
                 shuffle: bool = False, sampler: Optional[Sampler[int]] = None,
                 batch_sampler: Optional[Sampler[Sequence[int]]] = None,
                 num_workers: int = 0, collate_fn: _collate_fn_t = None,
                 pin_memory: bool = False, drop_last: bool = False,
                 timeout: float = 0, worker_init_fn: _worker_init_fn_t = None,
                 multiprocessing_context=None, generator=None,
                 *, prefetch_factor: int = 2,
                 persistent_workers: bool = False):

  在DataLoader类的众多参数中,可以看到有两种sampler:sampler和batch_sampler,都默认为None。sampler的作用是生成一系列的index,而batch_sampler则是将sampler生成的indices打包分组,按batch_size的大小划分为batch组的index。

  sampler和batch_sampler均由其对应的类来实现,两个类的定义中均含有def __iter__(self):,表明其为迭代器(迭代器的相关知识还不太了解,有错请指出),在调试过程中两个类的返回值均为地址值,未看到图片索引列表,如下图:

sampler和batch_sampler返回值
  其中,sampler和batch_sampler的元素个数分别为5717、1430,此时batch_size设置为4,5717 / 4 = 1429...1,证明batch_sampler是将sampler生成的indices,按照batch_size的大小划分为batch组的index。

  下面这段代码是class BatchSampler(Sampler[List[int]]):类中对batch划分的实现,通过yeild实现对batch的返回。yeils与生成器self.sample中存放图片的索引。

    def __iter__(self):
        batch = []
        for idx in self.sampler:
            batch.append(idx)
            if len(batch) == self.batch_size:
                yield batch
                batch = []
        if len(batch) > 0 and not self.drop_last:
            yield batch

  Dataloader对数据的处理主要就通过sampler和batch_sampler实现。

参考资料:
Pytorch的DataLoader, DataSet, Sampler之间的关系
Pytorch Sampler详解
Python迭代器

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值