PyTorch | 如何控制dataloader的随机shuffle


前言 在使用PyTorch进行训练或者测试的过程中,一般来说dataloader在每个epoch返回的样本顺序是不一样的,但在某些特殊情况中,我们可能希望dataloader按照固定的顺序进行多个epoch。本文作者给出了一个简单方便的实现思路,附详解代码。

作者:魏鸿鑫@知乎

编辑:CV技术指南

原文:https://zhuanlan.zhihu.com/p/515697362

问题背景


在使用PyTorch进行训练或者测试的过程中,一般来说dataloader在每个epoch返回的样本顺序是不一样的,但在某些特殊情况中,我们可能希望dataloader按照固定的顺序进行多个epoch, 或者说,在一个epoch中按照固定的顺序进行多次的样本循环iteration。

现有Sampler


默认的 RandomSampler 在生成iteration的时候会重新做一次random shuffle,所以无法直接实现这个需求。

def __iter__(self) -> Iterator[int]:
      n = len(self.data_source)
      if self.generator is None:
          seed = int(torch.empty((), dtype=torch.int64).random_().item())
          generator = torch.Generator()
          generator.manual_seed(seed)
      else:
          generator = self.generator

      if self.replacement:
          for _ in range(self.num_samples // 32):
              yield from torch.randint(high=n, size=(32,), dtype=torch.int64, generator=generator).tolist()
          yield from torch.randint(high=n, size=(self.num_samples % 32,), dtype=torch.int64, generator=generator).tolist()
      else:
          for _ in range(self.num_samples // n):
              yield from torch.randperm(n, generator=generator).tolist()
          yield from torch.randperm(n, generator=generator).tolist()[:self.num_samples % n]

上面的代码是RandomSampler中最重要的__iter__函数,我们可以看到每次调用这个函数或者新的iter时会得到一个新的随机顺序的iteration。

再看看另一个常用的sampler,也就是 SequentialSampler。我们在test的时候经常会设置shuffle=false,这时候就相当于使用了SequentialSampler:

class SequentialSampler(Sampler[int]):
    r"""Samples elements sequentially, always in the same order.

    Args:
        data_source (Dataset): dataset to sample from
    """
    data_source: Sized

    def __init__(self, data_source: Sized) -> None:
        self.data_source = data_source

    def __iter__(self) -> Iterator[int]:
        return iter(range(len(self.data_source)))

    def __len__(self) -> int:
        return len(self.data_source)

在代码中可以看到,这个sampler就是简单地创造并返回一个range序列,无法对其进行shuffle操作。

解决方案


结合上面两个现有的sampler,我们可以简单地自定义一个新的sampler来实现我们的需求。也就是说,我们希望能够手动控制何时进行shuffle操作,在没有shuffle时我们希望sampler按照前面的顺序返回iteration。

下面是我的实现:

class MySequentialSampler(SequentialSampler):
    def __init__(self, data_source, num_data=None):
        self.data_source = data_source
        self.my_list = list(range(len(self.data_source)))
        random.shuffle(self.my_list)
        if num_data is None:
            self.num_data = len(self.my_list)
        else:
            self.num_data = num_data
            self.my_list = self.my_list[:num_data]

    def __iter__(self):
        return iter(self.my_list)

    def __len__(self):
        return self.num_data

    def shuffle(self):
        self.my_list = list(range(len(self.data_source)))
        random.shuffle(self.my_list)
        self.my_list = self.my_list[:self.num_data]

这个实现非常简单而且使用方便。在默认情况下基本等同于SequentialSampler (去掉init函数中的shuffle即完全一致)。当我们需要重新shuffle序列的时候,只需要调用shuffle函数即可,比如:dataloader.sampler.shuffle(). 通过这个自定义sampler,我们就可以实现在指定的时候进行shuffle操作,而不是固定在每个iteration结束时进行shuffle。


ps: 理论上也可以直接通过对dataset进行shuffle,但这样操作的缺点是会改变对应的index,另外一般我们在train或者test函数中不会获取到dataset,而只能从loader进行操作(dataloader.dataset一般只能获取到length)。因此,修改sampler可以说是对原训练方法流程最少的方式。

猜您喜欢:

14ed2a08b67fff5c023b2ef33f790e3b.png 戳我,查看GAN的系列专辑~!

一顿午饭外卖,成为CV视觉前沿弄潮儿!

CVPR 2022 | 25+方向、最新50篇GAN论文

 ICCV 2021 | 35个主题GAN论文汇总

超110篇!CVPR 2021最全GAN论文梳理

超100篇!CVPR 2020最全GAN论文梳理

拆解组新的GAN:解耦表征MixNMatch

StarGAN第2版:多域多样性图像生成

附下载 | 《可解释的机器学习》中文版

附下载 |《TensorFlow 2.0 深度学习算法实战》

附下载 |《计算机视觉中的数学方法》分享

《基于深度学习的表面缺陷检测方法综述》

《零样本图像分类综述: 十年进展》

《基于深度神经网络的少样本学习综述》

64e2c697b1927e287e6b3972338258f5.png

  • 0
    点赞
  • 8
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
dataloaderPyTorch中用于加载数据的工具,它可以方便地对数据进行预处理、转换和加载。在使用dataloader加载数据时,我们可以通过设置随机种子来控制数据的洗牌(shuffle)过程。 为什么需要设置随机种子呢?在训练深度学习模型时,通常需要将数据集随机打乱来增加模型的泛化性能,减少模型对数据的依赖性。随机洗牌可以打破输入数据的顺序性,使模型更好地学习数据的特征和规律。但是,在同样的训练过程中,我们希望每次运行时的随机结果都是一致的,这样才能确保模型的可复现性,方便我们进行调试和比较实验结果。 因此,我们可以通过设置随机种子来控制dataloader加载数据时的洗牌过程。在使用dataloader加载数据时,可以通过设置参数shuffle=True来开启洗牌功能。同时,我们可以通过设置参数torch.manual_seed(seed)来设置随机种子的值。这样,每次运行时,dataloader在洗牌数据时都会使用相同的随机种子,从而保证了每次的洗牌结果都是一致的。 例如,我们可以使用以下代码来设置随机种子为10,并开启洗牌功能: torch.manual_seed(10) dataloader = DataLoader(dataset, shuffle=True) 这样,每次运行时,dataloader都会使用相同的随机种子10进行数据洗牌,从而保证了每次的洗牌结果都是一致的。这对于实验结果的比较和模型的可复现性非常重要。 总之,通过设置随机种子,我们可以在dataloader加载数据时控制洗牌过程,确保每次洗牌结果的一致性,从而提高模型的可复现性和实验的可比性。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值