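Read first: Pytorch Dataloader入门-CSDN博客
Beyond the commonly used basic features, this series explains all the remaining DataLoader parameters and looks closely at how DataLoader works, to give a more complete picture of it. Following the order of the DataLoader parameters, this post covers the sampler parameter.
sampler (Sampler or Iterable, optional):
The strategy for drawing samples from the dataset. It can be custom (any Iterable that implements __len__), or you can rely on one of the defaults that DataLoader provides. There are two defaults: RandomSampler (random sampling) and SequentialSampler (sequential sampling). Which default is used is decided by the shuffle parameter. For that reason the sampler and shuffle parameters conflict: passing a sampler together with shuffle=True raises a ValueError, so specify only one of the two.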
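To make the two points above concrete, here is a minimal sketch. ToyDataset and ReverseSampler are hypothetical names invented for this illustration; they are not part of PyTorch. The sketch assumes a map-style dataset and the default num_workers=0.

```python
from torch.utils.data import DataLoader, Dataset


class ToyDataset(Dataset):
    """Hypothetical map-style dataset holding the integers 0..n-1."""

    def __init__(self, n: int) -> None:
        self.n = n

    def __len__(self) -> int:
        return self.n

    def __getitem__(self, idx: int) -> int:
        return idx


class ReverseSampler:
    """Hypothetical custom sampler: an iterable of indices that also defines __len__."""

    def __init__(self, n: int) -> None:
        self.n = n

    def __iter__(self):
        # Yield indices n-1, n-2, ..., 0.
        return iter(range(self.n - 1, -1, -1))

    def __len__(self) -> int:
        return self.n


dataset = ToyDataset(6)

# The custom sampler decides the order in which indices are fetched.
loader = DataLoader(dataset, batch_size=2, sampler=ReverseSampler(len(dataset)))
batches = [batch.tolist() for batch in loader]
print(batches)

# Passing a sampler together with shuffle=True is rejected by DataLoader.
try:
    DataLoader(dataset, sampler=ReverseSampler(len(dataset)), shuffle=True)
    conflict_raised = False
except ValueError:
    conflict_raised = True
print(conflict_raised)
```

The sampler only produces indices; DataLoader still handles batching and collation, which is why the batches come out as pairs of consecutive reversed indices.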
The main decision logic is at: torch/utils/data/dataloader.py#L262-L272
if sampler is None:  # give default samplers
    if self._dataset_kind == _DatasetKind.Iterable:
        # See NOTE [ Custom Samplers and IterableDataset ]
        sampler = _InfiniteConstantSampler()
    else:  # map-style
        if shuffle:
            # Cannot statically verify that dataset is Sized
            # Somewhat related: see NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
            sampler = RandomSampler(dataset, generator=generator)  # type: ignore[arg-type]
        else:
            sampler = SequentialSampler(dataset)  # type: ignore[arg-type]
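The branch above can be observed from the outside through the sampler attribute of a DataLoader. The following is a small sketch rather than a test of every code path: it only covers map-style datasets and uses a plain Python list as the dataset, which is an assumption made for brevity.

```python
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler

data = list(range(10))  # a plain list behaves like a map-style dataset here

default_loader = DataLoader(data)                 # shuffle is off by default
shuffled_loader = DataLoader(data, shuffle=True)  # asks for random order

# Inspect which default sampler each loader ended up with.
uses_sequential = isinstance(default_loader.sampler, SequentialSampler)
uses_random = isinstance(shuffled_loader.sampler, RandomSampler)
print(type(default_loader.sampler).__name__, type(shuffled_loader.sampler).__name__)
```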
If you pass a custom Sampler, DataLoader uses it. Otherwise, the PyTorch DataLoader falls back to one of two default Samplers:
If shuffle=True, RandomSampler is used and samples are drawn from the dataset in random order.
class RandomSampler(Sampler[int]):
    r"""Samples elements randomly. If without replacement, then sample from a shuffled dataset.
    If with replacement, then user can specify :attr:`num_samples` to draw.

    Args:
        data_source (Dataset): dataset to sample from
        replacement (bool): samples are drawn on-demand with replacement if ``True``, default=``False``
        num_samples (int): number of samples to draw, default=`len(dataset)`. This argument
            is supposed to be specified only when `replacement` is ``True``.
        generator (Generator): Generator used in sampling.
    """
    data_source: Sized
    replacement: bool

    def __init__(self, data_source: Sized, replacement: bool = False,
                 num_samples: Optional[int] = None, generator=None) -> None:
        self.data_source = data_source
        self.replacement = replacement
        self._num_samples = num_samples
        self.generator = generator

    @property
    def num_samples(self) -> int:
        # dataset size might change at runtime
        if self._num_samples is None:
            return len(self.data_source)
        return self._num_samples

    def __iter__(self) -> Iterator[int]:
        n = len(self.data_source)
        if self.generator is None:
            generator = torch.Generator()
            generator.manual_seed(int(torch.empty((), dtype=torch.int64).random_().item()))
        else:
            generator = self.generator
        if self.replacement:
            for _ in range(self.num_samples // 32):
                yield from torch.randint(high=n, size=(32,), dtype=torch.int64, generator=generator).tolist()
            yield from torch.randint(high=n, size=(self.num_samples % 32,), dtype=torch.int64, generator=generator).tolist()
        else:
            yield from torch.randperm(n, generator=generator).tolist()

    def __len__(self) -> int:
        return self.num_samples
'''
torch.randint(high=n, size=(self.num_samples % 32,), dtype=torch.int64, generator=generator).tolist()
[6, 8, 9, 2, 3, 7, 1, 4, 7, 7]
torch.randperm(n, generator=generator).tolist()
[5, 0, 1, 4, 9, 6, 3, 2, 8, 7]
'''
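When replacement is True, sampling is done with replacement: the same sample may be drawn more than once, so the generated indices can contain duplicates. This is the `if self.replacement:` branch in the code above, which draws indices with torch.randint.
When replacement is False, sampling is done without replacement: each sample is drawn exactly once per pass, so the generated indices contain no duplicates. This is the `else` branch in the code above, which yields a torch.randperm permutation.
replacement defaults to False.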
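The difference between the two modes can be checked directly by iterating a RandomSampler. This is a minimal sketch assuming a 10-element list as the data source; the sample counts and the seed are arbitrary choices for illustration.

```python
import torch
from torch.utils.data import RandomSampler

data = list(range(10))

# Without replacement (the default): one pass is a permutation of 0..9.
no_repl = list(RandomSampler(data))

# With replacement: num_samples may differ from len(data), and repeats are possible.
with_repl = list(RandomSampler(data, replacement=True, num_samples=25))

# Re-seeding the same generator before each pass reproduces the same order.
g = torch.Generator()
g.manual_seed(0)
first = list(RandomSampler(data, generator=g))
g.manual_seed(0)
second = list(RandomSampler(data, generator=g))

print(sorted(no_repl) == data, len(with_repl), first == second)
```

Since 25 indices are drawn from only 10 possible values, the with-replacement result necessarily contains duplicates.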
Otherwise SequentialSampler is used and samples are drawn from the dataset in order.
class SequentialSampler(Sampler[int]):
    # *** (part of the class body elided here)

    def __iter__(self) -> Iterator[int]:
        return iter(range(len(self.data_source)))

    def __len__(self) -> int:
        return len(self.data_source)
It returns the data indices in order, based on the length of the data. The indices run from 0 to len(data) - 1.
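As a quick check of this behaviour, the sketch below iterates a SequentialSampler and then a DataLoader that relies on it. The five-element list and batch_size=2 are assumptions chosen for illustration.

```python
from torch.utils.data import DataLoader, SequentialSampler

data = [10, 20, 30, 40, 50]

# The sampler yields indices, not the data itself.
indices = list(SequentialSampler(data))

# With shuffling left off, DataLoader fetches samples in that same index order.
loader = DataLoader(data, batch_size=2)
batches = [batch.tolist() for batch in loader]

print(indices)
print(batches)
```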
GitHub - gy-7/CV-Road (all code for the upcoming tutorials will be maintained in this repository)