Pytorch Dataloader之sampler


先看:Pytorch Dataloader入门-CSDN博客

dataloader除了用的比较多的一些基本功能外,本文还想讲一下dataloader的其余所有参数详解,详细研究一下dataloder是怎么工作的,帮助大家更加全面的认识dataloader。那就按照dataloder参数的顺序来详细分析一下,本文分析dataloder参数中的sampler。


sampler (Sampler or Iterable, 可选):

从数据集中采样的策略。可以自定义,任何实现了 __len__方法的 Iterable;也可以使用dataloader中提供的默认的sampler。dataloader中提供了两个默认的sampler:RandomSampler(随机采样)和SequentialSampler(顺序采样)。使用那个默认的sampler是由shuffle参数决定的。因此,sampler参数与shuffle参数是冲突的,两个只能指定一个。

这里的主要判断逻辑在:torch/utils/data/dataloader.py#L262-L272

        if sampler is None:  # give default samplers
            if self._dataset_kind == _DatasetKind.Iterable:
                # See NOTE [ Custom Samplers and IterableDataset ]
                sampler = _InfiniteConstantSampler()
            else:  # map-style
                if shuffle:
                    # Cannot statically verify that dataset is Sized
                    # Somewhat related: see NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
                    sampler = RandomSampler(dataset, generator=generator)  # type: ignore[arg-type]
                else:
                    sampler = SequentialSampler(dataset)  # type: ignore[arg-type]

如果你自定义了Sampler,Dataloder则采用你定义的Sampler。除此之外,Pytorch Dataloader提供了两种默认的Sampler:

如果shuffle=True,则使用RandomSampler,随机从数据集中进行采样。

class RandomSampler(Sampler[int]):
    r"""Samples elements randomly. If without replacement, then sample from a shuffled dataset.
    If with replacement, then user can specify :attr:`num_samples` to draw.

    Args:
        data_source (Dataset): dataset to sample from
        replacement (bool): samples are drawn on-demand with replacement if ``True``, default=``False``
        num_samples (int): number of samples to draw, default=`len(dataset)`. This argument
            is supposed to be specified only when `replacement` is ``True``.
        generator (Generator): Generator used in sampling.
    """
    data_source: Sized
    replacement: bool

    def __init__(self, data_source: Sized, replacement: bool = False,
                 num_samples: Optional[int] = None, generator=None) -> None:
        self.data_source = data_source
        self.replacement = replacement
        self._num_samples = num_samples
        self.generator = generator

    @property
    def num_samples(self) -> int:
        # dataset size might change at runtime
        if self._num_samples is None:
            return len(self.data_source)
        return self._num_samples

    def __iter__(self) -> Iterator[int]:
        n = len(self.data_source)
        if self.generator is None:
            generator = torch.Generator()
            generator.manual_seed(int(torch.empty((), dtype=torch.int64).random_().item()))
        else:
            generator = self.generator
        if self.replacement:
            for _ in range(self.num_samples // 32):
                yield from torch.randint(high=n, size=(32,), dtype=torch.int64, generator=generator).tolist()
            yield from torch.randint(high=n, size=(self.num_samples % 32,), dtype=torch.int64, generator=generator).tolist()
        else:
            yield from torch.randperm(n, generator=generator).tolist()

    def __len__(self) -> int:
        return self.num_samples
        
'''       
torch.randint(high=n, size=(self.num_samples % 32,), dtype=torch.int64, generator=generator).tolist()
[6, 8, 9, 2, 3, 7, 1, 4, 7, 7]
torch.randperm(n, generator=generator).tolist()
[5, 0, 1, 4, 9, 6, 3, 2, 8, 7]
'''
  • 当 replacement 设置为 True 时,采样是有放回的,即同一个样本可能会被多次采样到,生成的索引可能存在重复。上述代码line48。
  • 当 replacement 设置为 False 时,采样是无放回的,即每个样本只会被采样一次,生成的索引不会有重复。上述代码line50。

  • replacement默认为False。

否则使用SequentialSampler,顺序从数据集中进行采样。

class SequentialSampler(Sampler[int]):
    ***
    def __iter__(self) -> Iterator[int]:
        return iter(range(len(self.data_source)))

    def __len__(self) -> int:
        return len(self.data_source)

根据数据的长度,顺序返回数据的索引。下标是0~数据长度-1


嗨,欢迎大家关注我的公众号《CV之路》,一起讨论问题,一起学习进步~。也欢迎大家关注我的GitHub仓库,我出的所有博文教程都是无偿分享的,只求个关注与Star~,多谢大家支持!

GitHub - gy-7/CV-Road (后续教程相关所有代码都会维护到此仓库)

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

gy-7

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值