为了从数据集中读取数据,PyTorch提供了Sampler基类与多个子类,实现不同方式的数据采样。
1.基类Sampler
class Sampler(object):
    r"""Abstract base class for every index sampler.

    A concrete sampler must override ``__iter__`` and return an iterable of
    dataset indices; the base class deliberately provides no default
    iteration behaviour.
    """

    def __init__(self, data_source):
        # The base class stores nothing; subclasses decide what to keep.
        pass

    def __iter__(self):
        # Subclasses are required to supply their own index iterator.
        raise NotImplementedError
所有的采样器都要继承Sampler类,必须实现的方法为__iter__(),返回可迭代对象。
2.Sequential Sampler
class SequentialSampler(Sampler):
    r"""Yield dataset indices in ascending order: 0, 1, ..., len-1.

    Arguments:
        data_source (Dataset): dataset to sample from
    """

    def __init__(self, data_source):
        # Any sized object works; only len() is ever called on it.
        self.data_source = data_source

    def __iter__(self):
        # Same deterministic order every epoch.
        total = len(self.data_source)
        return iter(range(total))

    def __len__(self):
        return len(self.data_source)
3.随机采样RandomSampler
class RandomSampler(Sampler):
    r"""Samples elements randomly. If without replacement, then sample from a shuffled dataset.
    If with replacement, then user can specify :attr:`num_samples` to draw.
    Arguments:
        data_source (Dataset): dataset to sample from
        replacement (bool): samples are drawn with replacement if ``True``, default=``False``
        num_samples (int): number of samples to draw, default=`len(dataset)`. This argument
            is supposed to be specified only when `replacement` is ``True``.
    """

    def __init__(self, data_source, replacement=False, num_samples=None):
        self.data_source = data_source
        # When True, the same index may be drawn more than once.
        self.replacement = replacement
        self._num_samples = num_samples
        # (type checks omitted, as in the original)

    @property
    def num_samples(self):
        # Fall back to the current dataset length when no explicit count
        # was given — the dataset size might change at runtime.
        if self._num_samples is not None:
            return self._num_samples
        return len(self.data_source)

    def __iter__(self):
        population = len(self.data_source)
        if not self.replacement:
            # A random permutation: every index appears exactly once.
            return iter(torch.randperm(population).tolist())
        # Independent uniform draws: duplicates are possible.
        draws = torch.randint(high=population, size=(self.num_samples,), dtype=torch.int64)
        return iter(draws.tolist())

    def __len__(self):
        # Number of indices produced per epoch, not the dataset length.
        return self.num_samples
4.子集随机采样Subset Random Sampler
class SubsetRandomSampler(Sampler):
    r"""Samples elements randomly from a given list of indices, without replacement.

    Useful for carving one dataset into train/validation/test splits.

    Arguments:
        indices (sequence): a sequence of indices
    """

    def __init__(self, indices):
        # A fixed slice of the dataset's index space (e.g. a train split).
        self.indices = indices

    def __iter__(self):
        # Lazily map a fresh random permutation of positions onto the
        # stored indices; each index appears exactly once.
        return map(self.indices.__getitem__, torch.randperm(len(self.indices)))

    def __len__(self):
        return len(self.indices)
Subset Random Sampler应该用于训练集、测试集和验证集的划分
5.批采样BatchSampler
class BatchSampler(Sampler):
    r"""Wraps another sampler to yield a mini-batch of indices.

    Arguments:
        sampler (Sampler or iterable): the source of individual indices
        batch_size (int): number of indices per batch
        drop_last (bool): if ``True``, drop a final batch smaller than
            ``batch_size``
    """

    # NOTE(review): the original line ended with a stray full-width `、`
    # after the colon, which is a SyntaxError; it has been removed.
    def __init__(self, sampler, batch_size, drop_last):
        # ...type checks omitted, as in the original
        # The underlying per-element sampler to draw indices from.
        self.sampler = sampler
        self.batch_size = batch_size
        # Whether to discard a trailing batch smaller than batch_size.
        self.drop_last = drop_last

    def __iter__(self):
        batch = []
        for idx in self.sampler:
            batch.append(idx)
            # Emit the batch as soon as it reaches the requested size.
            if len(batch) == self.batch_size:
                yield batch
                batch = []
        # Emit the leftover partial batch unless drop_last is set.
        if len(batch) > 0 and not self.drop_last:
            yield batch

    def __len__(self):
        # Number of batches per epoch: floor division when dropping the
        # remainder, ceiling division when keeping it.
        if self.drop_last:
            return len(self.sampler) // self.batch_size
        else:
            return (len(self.sampler) + self.batch_size - 1) // self.batch_size
drop_last为“True”时,如果采样得到的数据个数小于batch_size则抛弃本个batch的数据。