The constructor of the torch.utils.data.DataLoader class has a sampler parameter whose default value is None. The sampler and batch_sampler parameters let users specify their own data-loading order and sampling strategy.
torch.utils.data.Sampler is the base class of all samplers. It implements the following two methods:
def __init__(self, data_source):
    pass

def __iter__(self):
    raise NotImplementedError
The base class deliberately does not implement __len__(), so that subclasses which do not provide it will not raise an error; a subclass may implement __len__() optionally.
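To make the interface concrete, here is a minimal sketch of a custom sampler; the ReverseSampler name and its reverse-order behavior are illustrative assumptions, not part of PyTorch:

from torch.utils.data import Sampler

class ReverseSampler(Sampler):
    """Illustrative sampler that yields dataset indices from last to first."""

    def __init__(self, data_source):
        self.data_source = data_source

    def __iter__(self):
        # Walk the indices backwards: len-1, len-2, ..., 0.
        return iter(range(len(self.data_source) - 1, -1, -1))

    def __len__(self):
        # Optional, but lets DataLoader report a length.
        return len(self.data_source)

For example, list(ReverseSampler(range(5))) yields [4, 3, 2, 1, 0].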
SequentialSampler
As the name suggests, this sampler draws data from the dataset in sequential index order. The source is as follows:
class SequentialSampler(Sampler):
    r"""Samples elements sequentially, always in the same order.

    Arguments:
        data_source (Dataset): dataset to sample from
    """

    def __init__(self, data_source):
        self.data_source = data_source

    def __iter__(self):
        return iter(range(len(self.data_source)))

    def __len__(self):
        return len(self.data_source)
As the source shows, this sampler simply yields indices in sequential order.
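A quick usage sketch (with an assumed toy dataset) shows the behavior:

from torch.utils.data import SequentialSampler

data = ['a', 'b', 'c', 'd']           # anything with __len__ works as data_source
sampler = SequentialSampler(data)
print(list(sampler))                   # [0, 1, 2, 3]
print(len(sampler))                    # 4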
RandomSampler
This is a random sampler with two modes: sampling with replacement and sampling without replacement.
class RandomSampler(Sampler):
    r"""Samples elements randomly. If without replacement, then sample from a shuffled dataset.
    If with replacement, then user can specify :attr:`num_samples` to draw.

    Arguments:
        data_source (Dataset): dataset to sample from
        replacement (bool): samples are drawn with replacement if ``True``, default=``False``
        num_samples (int): number of samples to draw, default=`len(dataset)`. This argument
            is supposed to be specified only when `replacement` is ``True``.
    """

    def __init__(self, data_source, replacement=False, num_samples=None):
        self.data_source = data_source
        self.replacement = replacement
        self._num_samples = num_samples

        if not isinstance(self.replacement, bool):
            raise ValueError("replacement should be a boolean value, but got "
                             "replacement={}".format(self.replacement))

        if self._num_samples is not None and not replacement:
            raise ValueError("With replacement=False, num_samples should not be specified, "
                             "since a random permute will be performed.")

        if not isinstance(self.num_samples, int) or self.num_samples <= 0:
            raise ValueError("num_samples should be a positive integer "
                             "value, but got num_samples={}".format(self.num_samples))

    @property
    def num_samples(self):
        # dataset size might change at runtime
        if self._num_samples is None:
            return len(self.data_source)
        return self._num_samples

    def __iter__(self):
        n = len(self.data_source)
        if self.replacement:
            return iter(torch.randint(high=n, size=(self.num_samples,), dtype=torch.int64).tolist())
        return iter(torch.randperm(n).tolist())

    def __len__(self):
        return self.num_samples
As the docstring says, when sampling with replacement we can specify how many samples to draw; when sampling without replacement, the sampler draws a random permutation of the whole dataset, so the number of samples equals the dataset size. torch.randint and torch.randperm return the index arrays for sampling with and without replacement, respectively.
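A small sketch of the two modes (the exact indices depend on the random seed):

import torch
from torch.utils.data import RandomSampler

data = list(range(6))
torch.manual_seed(0)

# Without replacement (default): a random permutation of all 6 indices.
print(list(RandomSampler(data)))

# With replacement: indices may repeat and num_samples can be chosen freely.
print(list(RandomSampler(data, replacement=True, num_samples=10)))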
SubsetRandomSampler
class SubsetRandomSampler(Sampler):
    r"""Samples elements randomly from a given list of indices, without replacement.

    Arguments:
        indices (sequence): a sequence of indices
    """

    def __init__(self, indices):
        self.indices = indices

    def __iter__(self):
        return (self.indices[i] for i in torch.randperm(len(self.indices)))

    def __len__(self):
        return len(self.indices)
As the docstring states, this class randomly samples, without replacement, from a given list of indices.
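A common usage sketch is a simple train/validation split over a fixed set of indices (the 8/2 split here is just an assumption for illustration):

from torch.utils.data import SubsetRandomSampler

indices = list(range(10))
train_sampler = SubsetRandomSampler(indices[:8])   # a shuffled view of indices 0..7
val_sampler = SubsetRandomSampler(indices[8:])     # a shuffled view of indices 8..9
print(list(train_sampler))   # a permutation of [0, ..., 7]
print(list(val_sampler))     # a permutation of [8, 9]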
WeightedRandomSampler
This sampler draws indices according to given probabilities (weights); as the source shows, it simply delegates to torch.multinomial. Both the with-replacement and without-replacement modes are supported, since torch.multinomial already implements them. A usage sketch follows the source below.
class WeightedRandomSampler(Sampler):
    r"""Samples elements from ``[0,..,len(weights)-1]`` with given probabilities (weights).

    Args:
        weights (sequence): a sequence of weights, not necessary summing up to one
        num_samples (int): number of samples to draw
        replacement (bool): if ``True``, samples are drawn with replacement.
            If not, they are drawn without replacement, which means that when a
            sample index is drawn for a row, it cannot be drawn again for that row.

    Example:
        >>> list(WeightedRandomSampler([0.1, 0.9, 0.4, 0.7, 3.0, 0.6], 5, replacement=True))
        [0, 0, 0, 1, 0]
        >>> list(WeightedRandomSampler([0.9, 0.4, 0.05, 0.2, 0.3, 0.1], 5, replacement=False))
        [0, 1, 4, 3, 2]
    """

    def __init__(self, weights, num_samples, replacement=True):
        if not isinstance(num_samples, _int_classes) or isinstance(num_samples, bool) or \
                num_samples <= 0:
            raise ValueError("num_samples should be a positive integer "
                             "value, but got num_samples={}".format(num_samples))
        if not isinstance(replacement, bool):
            raise ValueError("replacement should be a boolean value, but got "
                             "replacement={}".format(replacement))
        self.weights = torch.as_tensor(weights, dtype=torch.double)
        self.num_samples = num_samples
        self.replacement = replacement

    def __iter__(self):
        return iter(torch.multinomial(self.weights, self.num_samples, self.replacement).tolist())

    def __len__(self):
        return self.num_samples
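As promised above, a usage sketch for an assumed class-imbalance scenario: each sample is weighted by the inverse frequency of its class, so the minority class is drawn about as often as the majority class.

import torch
from torch.utils.data import WeightedRandomSampler

labels = torch.tensor([0, 0, 0, 0, 0, 0, 0, 0, 1, 1])   # 8 samples of class 0, 2 of class 1
class_counts = torch.bincount(labels).double()            # tensor([8., 2.])
weights = 1.0 / class_counts[labels]                      # per-sample weight = 1 / class frequency
sampler = WeightedRandomSampler(weights, num_samples=10, replacement=True)
print(list(sampler))   # class-1 indices (8, 9) are drawn about as often as all class-0 indices combined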
BatchSampler
Another kind of sampler, which yields mini-batches of indices; the batch size and the drop_last flag can be configured.
class BatchSampler(Sampler):
    r"""Wraps another sampler to yield a mini-batch of indices.

    Args:
        sampler (Sampler): Base sampler.
        batch_size (int): Size of mini-batch.
        drop_last (bool): If ``True``, the sampler will drop the last batch if
            its size would be less than ``batch_size``

    Example:
        >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False))
        [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
        >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=True))
        [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
    """

    def __init__(self, sampler, batch_size, drop_last):
        if not isinstance(sampler, Sampler):
            raise ValueError("sampler should be an instance of "
                             "torch.utils.data.Sampler, but got sampler={}"
                             .format(sampler))
        if not isinstance(batch_size, _int_classes) or isinstance(batch_size, bool) or \
                batch_size <= 0:
            raise ValueError("batch_size should be a positive integer value, "
                             "but got batch_size={}".format(batch_size))
        if not isinstance(drop_last, bool):
            raise ValueError("drop_last should be a boolean value, but got "
                             "drop_last={}".format(drop_last))
        self.sampler = sampler
        self.batch_size = batch_size
        self.drop_last = drop_last

    def __iter__(self):
        batch = []
        for idx in self.sampler:
            batch.append(idx)
            if len(batch) == self.batch_size:
                yield batch
                batch = []
        if len(batch) > 0 and not self.drop_last:
            yield batch

    def __len__(self):
        if self.drop_last:
            return len(self.sampler) // self.batch_size
        else:
            return (len(self.sampler) + self.batch_size - 1) // self.batch_size
As we can see, one input to this class is itself a sampler, i.e., a stream of sample indices; BatchSampler re-samples that stream to produce batches of indices. Note what the DataLoader documentation says about batch_sampler: "like sampler, but returns a batch of indices at a time. Mutually exclusive with batch_size, shuffle, sampler, and drop_last." That is, batch_sampler cannot be combined with the batch_size, shuffle, sampler, or drop_last arguments.
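A sketch of passing a BatchSampler to DataLoader; because of the mutual exclusivity above, batch_size, shuffle, sampler and drop_last are all left at their defaults (the toy TensorDataset is just an assumption for illustration):

import torch
from torch.utils.data import DataLoader, TensorDataset, RandomSampler, BatchSampler

dataset = TensorDataset(torch.arange(10).float().unsqueeze(1))
batch_sampler = BatchSampler(RandomSampler(dataset), batch_size=4, drop_last=False)
loader = DataLoader(dataset, batch_sampler=batch_sampler)
for (batch,) in loader:
    print(batch.shape)   # torch.Size([4, 1]), torch.Size([4, 1]), torch.Size([2, 1])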
Having gone through the code, let's revisit the explanation under "Data Loading Order and Sampler" in the official documentation:
torch.utils.data.Sampler classes are used to specify the sequence of indices/keys used in data loading. They represent iterable objects over the indices to datasets. E.g., in the common case with stochastic gradient descent (SGD), a Sampler could randomly permute a list of indices and yield each one at a time, or yield a small number of them for mini-batch SGD.
A sequential or shuffled sampler will be automatically constructed based on the shuffle argument to a DataLoader. Alternatively, users may use the sampler argument to specify a custom Sampler object that at each time yields the next index/key to fetch.
A custom Sampler that yields a list of batch indices at a time can be passed as the batch_sampler argument. Automatic batching can also be enabled via the batch_size and drop_last arguments. See the next section for more details on this.
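Putting it together, a brief sketch of the two per-index ways to control loading order described above (the toy dataset is again just an assumption for illustration):

import torch
from torch.utils.data import DataLoader, TensorDataset, RandomSampler

dataset = TensorDataset(torch.arange(8).float().unsqueeze(1))

# 1) Let DataLoader build the sampler from `shuffle`:
#    shuffle=False -> SequentialSampler, shuffle=True -> RandomSampler.
loader_a = DataLoader(dataset, batch_size=4, shuffle=True)

# 2) Pass an explicit Sampler; roughly equivalent to shuffle=True above.
loader_b = DataLoader(dataset, batch_size=4, sampler=RandomSampler(dataset))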