Learning torch.utils.data.Sampler in PyTorch

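The constructor of the torch.utils.data.DataLoader class has a sampler parameter whose default value is None. The sampler and batch_sampler parameters let users specify the order in which data is loaded and how it is sampled.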

torch.utils.data.Sampler is the base class of all samplers. It implements the following two methods:

    def __init__(self, data_source):
        pass

    def __iter__(self):
        raise NotImplementedError

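The base class deliberately does not implement __len__(), so that subclasses which do not provide it don't run into errors. For example, list(sampler) tries len() first and falls back to plain iteration when __len__ is missing, whereas a default __len__ that raised NotImplementedError would make the call fail. Subclasses may implement __len__() optionally.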
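To make this contract concrete, here is a minimal sketch of a custom sampler. `ReverseSampler` is a hypothetical name used only for illustration: it subclasses `torch.utils.data.Sampler`, implements `__iter__` to yield indices from last to first, and opts in to `__len__`. A plain Python list stands in for the dataset.

```python
from torch.utils.data import DataLoader, Sampler


class ReverseSampler(Sampler):
    """Hypothetical sampler that yields dataset indices from last to first."""

    def __init__(self, data_source):
        # Keep a reference to the dataset; only its length is needed
        self.data_source = data_source

    def __iter__(self):
        # Yield len-1, len-2, ..., 0
        return iter(range(len(self.data_source) - 1, -1, -1))

    def __len__(self):
        # Optional: lets len(sampler) and len(loader) work
        return len(self.data_source)


data = [10, 11, 12, 13, 14]  # a plain list can serve as a map-style dataset
sampler = ReverseSampler(data)
print(list(sampler))  # [4, 3, 2, 1, 0]

loader = DataLoader(data, batch_size=2, sampler=sampler)
for batch in loader:
    print(batch)  # tensor([14, 13]), then tensor([12, 11]), then tensor([10])
```

Because the sampler only produces indices, the DataLoader still does the actual fetching (`data[i]`) and collation into tensors.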

SequentialSampler

As the name suggests, this sampler draws the dataset's elements sequentially. The source code is:

from torch.utils.data import Sampler

class SequentialSampler(Sampler):
    r"""Samples elements sequentially, always in the same order.

    Arguments:
        data_source (Dataset): dataset to sample from
    """

    def __init__(self, data_source):
        self.data_source = data_source

    def __iter__(self):
        return iter(range(len(self.data_source)))

    def __len__(self):
        return len(self.data_source)

The source shows that this sampler simply yields the indices 0, 1, ..., len(data_source) - 1 in order.
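A short usage sketch, assuming a toy list of strings as the dataset. It also checks that a DataLoader created with the default shuffle=False and no sampler argument builds a SequentialSampler internally.

```python
from torch.utils.data import DataLoader, SequentialSampler

data = ["a", "b", "c", "d"]

sampler = SequentialSampler(data)
print(list(sampler))  # [0, 1, 2, 3]
print(len(sampler))   # 4

# With shuffle=False (the default) and no sampler given,
# DataLoader constructs a SequentialSampler itself
loader = DataLoader(data, batch_size=2)
print(type(loader.sampler).__name__)  # SequentialSampler
print(list(loader))                   # [['a', 'b'], ['c', 'd']]
```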

RandomSampler

This is a random sampler with two modes: sampling with replacement and sampling without replacement.

import torch
from torch.utils.data import Sampler

class RandomSampler(Sampler):
    r"""Samples elements randomly. If without replacement, then sample from a shuffled dataset.
    If with replacement, then user can specify :attr:`num_samples` to draw.

    Arguments:
        data_source (Dataset): dataset to sample from
        replacement (bool): samples are drawn with replacement if ``True``, default=``False``
        num_samples (int): number of samples to draw, default=`len(dataset)`. This argument
            is supposed to be specified only when `replacement` is ``True``.
    """

    def __init__(self, data_source, replacement=False, num_samples=None):
        self.data_source = data_source
        self.replacement = replacement
        self._num_samples = num_samples

        if not isinstance(self.replacement, bool):
            raise ValueError("replacement should be a boolean value, but got "
                             "replacement={}".format(self.replacement))

        if self._num_samples is not None and not replacement:
            raise ValueError("With replacement=False, num_samples should not be specified, "
                             "since a random permute will be performed.")

        if not isinstance(self.num_samples, int) or self.num_samples <= 0:
            raise ValueError("num_samples should be a positive integer "
                             "value, but got num_samples={}".format(self.num_samples))

    @property
    def num_samples(self):
        # dataset size might change at runtime
        if self._num_samples is None:
            return len(self.data_source)
        return self._num_samples

    def __iter__(self):
        n = len(self.data_source)
        if self.replacement:
            return iter(torch.randint(high=n, size=(self.num_samples,), dtype=torch.int64).tolist())
        return iter(torch.randperm(n).tolist())

    def __len__(self):
        return self.num_samples

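According to the docstring, when sampling with replacement we can specify the number of samples to draw; when sampling without replacement, the sampler draws from the whole dataset in random order, and the number of samples equals the dataset size (in the version shown here, passing num_samples together with replacement=False raises a ValueError). torch.randint and torch.randperm return the random index arrays with and without replacement, respectively.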
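A minimal sketch of both modes, assuming a toy six-element list as the dataset; the seed is fixed only to make runs repeatable.

```python
import torch
from torch.utils.data import RandomSampler

data = list(range(6))
torch.manual_seed(0)  # reproducible draws

# Without replacement: a random permutation, every index appears exactly once
perm = list(RandomSampler(data))
print(sorted(perm) == data)  # True

# With replacement: num_samples may differ from len(data), and repeats are allowed
draws = list(RandomSampler(data, replacement=True, num_samples=10))
print(len(draws))                              # 10
print(all(0 <= i < len(data) for i in draws))  # True

# The underlying primitives used in __iter__
print(torch.randperm(6).tolist())                  # some permutation of 0..5
print(torch.randint(high=6, size=(10,)).tolist())  # 10 draws from 0..5, repeats possible
```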

SubsetRandomSampler

import torch
from torch.utils.data import Sampler

class SubsetRandomSampler(Sampler):
    r"""Samples elements randomly from a given list of indices, without replacement.

    Arguments:
        indices (sequence): a sequence of indices
    """

    def __init__(self, indices):
        self.indices = indices

    def __iter__(self):
        return (self.indices[i] for i in torch.randperm(len(self.indices)))

    def __len__(self):
        return len(self.indices)

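As the docstring says, this class samples randomly, without replacement, from a given list of indices: each pass over the sampler yields every index in the list exactly once, in a freshly shuffled order.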
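A typical use is a train/validation split over one dataset. The sketch below assumes a toy list of ten integers and an arbitrary 7/3 split chosen only for illustration; each loader sees only its own subset, in random order.

```python
import torch
from torch.utils.data import DataLoader, SubsetRandomSampler

dataset = list(range(100, 110))  # 10 toy samples

# Hold out the last 3 samples for validation (split chosen for illustration)
indices = list(range(len(dataset)))
train_idx, val_idx = indices[:7], indices[7:]

train_loader = DataLoader(dataset, batch_size=4, sampler=SubsetRandomSampler(train_idx))
val_loader = DataLoader(dataset, batch_size=4, sampler=SubsetRandomSampler(val_idx))

# Concatenate all batches of one epoch to see which samples were visited
seen_train = torch.cat(list(train_loader)).tolist()
seen_val = torch.cat(list(val_loader)).tolist()
print(sorted(seen_train))  # [100, 101, 102, 103, 104, 105, 106]
print(sorted(seen_val))    # [107, 108, 109]
```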

WeightedRandomSampler

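This sampler draws data according to given probabilities (weights). As the source shows, it is a thin wrapper around torch.multinomial, so that function is all you need to understand. It supports sampling both with and without replacement, both of which torch.multinomial already implements.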
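A small sketch of torch.multinomial on its own, reusing the weights from the docstring example. The weights do not need to sum to one; index 4 (weight 3.0) should dominate draws with replacement, while draws without replacement never repeat an index.

```python
import torch

weights = torch.tensor([0.1, 0.9, 0.4, 0.7, 3.0, 0.6], dtype=torch.double)
torch.manual_seed(0)  # reproducible draws

# With replacement: indices may repeat, frequencies follow weights / weights.sum()
with_rep = torch.multinomial(weights, num_samples=1000, replacement=True)
counts = torch.bincount(with_rep, minlength=len(weights))
print(counts.argmax().item())  # expected: 4, the index with the largest weight

# Without replacement: each index appears at most once
no_rep = torch.multinomial(weights, num_samples=6, replacement=False).tolist()
print(sorted(no_rep))  # [0, 1, 2, 3, 4, 5]
```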

import torch
from torch.utils.data import Sampler

class WeightedRandomSampler(Sampler):
    r"""Samples elements from ``[0,..,len(weights)-1]`` with given probabilities (weights).

    Args:
        weights (sequence)   : a sequence of weights, not necessarily summing up to one
        num_samples (int): number of samples to draw
        replacement (bool): if ``True``, samples are drawn with replacement.
            If not, they are drawn without replacement, which means that when a
            sample index is drawn for a row, it cannot be drawn again for that row.

    Example:
        >>> list(WeightedRandomSampler([0.1, 0.9, 0.4, 0.7, 3.0, 0.6], 5, replacement=True))
        [4, 4, 1, 4, 5]
        >>> list(WeightedRandomSampler([0.9, 0.4, 0.05, 0.2, 0.3, 0.1], 5, replacement=False))
        [0, 1, 4, 3, 2]
    """

    def __init__(self, weights, num_samples, replacement=True):
        # `_int_classes` came from the since-removed torch._six module; in Python 3 it is just `int`
        if not isinstance(num_samples, int) or isinstance(num_samples, bool) or \
                num_samples <= 0:
            raise ValueError("num_samples should be a positive integer "
                             "value, but got num_samples={}".format(num_samples))
        if not isinstance(replacement, bool):
            raise ValueError("replacement should be a boolean value, but got "
                             "replacement={}".format(replacement))
        self.weights = torch.as_tensor(weights, dtype=torch.double)
        self.num_samples = num_samples
        self.replacement = replacement

    def __iter__(self):
        return iter(torch.multinomial(self.weights, self.num_samples, self.replacement).tolist())

    def __len__(self):
        return self.num_samples
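A common application is rebalancing an imbalanced dataset. The sketch below assumes a toy TensorDataset with 90 samples of class 0 and 10 of class 1, and weights each sample by the inverse frequency of its class; that weighting is one common heuristic, not the only choice. With both classes carrying equal total weight, about half of the drawn samples should come from class 1.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler

torch.manual_seed(0)  # reproducible toy data and sampling

# Toy imbalanced dataset: 90 samples of class 0 and 10 samples of class 1
labels = torch.tensor([0] * 90 + [1] * 10)
features = torch.randn(100, 3)
dataset = TensorDataset(features, labels)

# Inverse class frequency as the per-sample weight
class_counts = torch.bincount(labels)                 # tensor([90, 10])
sample_weights = 1.0 / class_counts[labels].double()  # 1/90 for class 0, 1/10 for class 1

sampler = WeightedRandomSampler(sample_weights, num_samples=len(dataset), replacement=True)
loader = DataLoader(dataset, batch_size=20, sampler=sampler)

# Collect the labels of everything drawn during one epoch
drawn_labels = torch.cat([y for _, y in loader])
print(len(drawn_labels))                   # 100
print(drawn_labels.float().mean().item())  # share of class 1, expected near 0.5 rather than 0.1
```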

BatchSampler

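BatchSampler wraps another sampler and yields mini-batches of indices instead of single indices. It likewise takes batch_size and drop_last arguments.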

from torch.utils.data import Sampler

class BatchSampler(Sampler):
    r"""Wraps another sampler to yield a mini-batch of indices.

    Args:
        sampler (Sampler): Base sampler.
        batch_size (int): Size of mini-batch.
        drop_last (bool): If ``True``, the sampler will drop the last batch if
            its size would be less than ``batch_size``

    Example:
        >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False))
        [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
        >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=True))
        [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
    """

    def __init__(self, sampler, batch_size, drop_last):
        if not isinstance(sampler, Sampler):
            raise ValueError("sampler should be an instance of "
                             "torch.utils.data.Sampler, but got sampler={}"
                             .format(sampler))
        # `_int_classes` came from the since-removed torch._six module; in Python 3 it is just `int`
        if not isinstance(batch_size, int) or isinstance(batch_size, bool) or \
                batch_size <= 0:
            raise ValueError("batch_size should be a positive integer value, "
                             "but got batch_size={}".format(batch_size))
        if not isinstance(drop_last, bool):
            raise ValueError("drop_last should be a boolean value, but got "
                             "drop_last={}".format(drop_last))
        self.sampler = sampler
        self.batch_size = batch_size
        self.drop_last = drop_last

    def __iter__(self):
        batch = []
        for idx in self.sampler:
            batch.append(idx)
            if len(batch) == self.batch_size:
                yield batch
                batch = []
        if len(batch) > 0 and not self.drop_last:
            yield batch

    def __len__(self):
        if self.drop_last:
            return len(self.sampler) // self.batch_size
        else:
            return (len(self.sampler) + self.batch_size - 1) // self.batch_size
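A quick check of the two branches of __len__, assuming 10 indices and a batch size of 3: (n + batch_size - 1) // batch_size is ceiling division and counts the partial last batch, while n // batch_size counts only full batches.

```python
from torch.utils.data import BatchSampler, SequentialSampler

n, batch_size = 10, 3
base = SequentialSampler(range(n))

keep = BatchSampler(base, batch_size=batch_size, drop_last=False)
drop = BatchSampler(base, batch_size=batch_size, drop_last=True)

# drop_last=False: (10 + 3 - 1) // 3 = 4 batches, the last one partial
print(len(keep), list(keep)[-1])  # 4 [9]
# drop_last=True: 10 // 3 = 3 full batches
print(len(drop), list(drop)[-1])  # 3 [6, 7, 8]
```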

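As the code shows, one input of this class is another sampler, i.e. a source of sampled indices; BatchSampler groups those indices a second time into batches. Note what the DataLoader documentation says about its batch_sampler argument: "like sampler, but returns a batch of indices at a time. Mutually exclusive with batch_size, shuffle, sampler, and drop_last."

In other words, when batch_sampler is passed to DataLoader, the batch_size, shuffle, sampler, and drop_last arguments must not be set.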
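A sketch of passing a BatchSampler to DataLoader, assuming a toy list of ten integers. The second part checks that combining batch_sampler with batch_size is rejected with a ValueError.

```python
from torch.utils.data import BatchSampler, DataLoader, RandomSampler

data = list(range(10))
batch_sampler = BatchSampler(RandomSampler(data), batch_size=4, drop_last=True)

# batch_sampler controls both order and batching, so batch_size, shuffle,
# sampler and drop_last are left at their defaults
loader = DataLoader(data, batch_sampler=batch_sampler)
batches = [b.tolist() for b in loader]
print(batches)  # two shuffled batches of 4; the remaining 2 samples are dropped

# Combining batch_sampler with batch_size is rejected
raised = False
try:
    DataLoader(data, batch_size=4, batch_sampler=batch_sampler)
except ValueError as err:
    raised = True
    print(err)  # the mutual-exclusivity error message
print(raised)  # True
```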

 

Having read the code, let's now look at the explanation in the "Data Loading Order and Sampler" section of the official documentation:

torch.utils.data.Sampler classes are used to specify the sequence of indices/keys used in data loading. They represent iterable objects over the indices to datasets. E.g., in the common case with stochastic gradient descent (SGD), a Sampler could randomly permute a list of indices and yield each one at a time, or yield a small number of them for mini-batch SGD.

A sequential or shuffled sampler will be automatically constructed based on the shuffle argument to a DataLoader. Alternatively, users may use the sampler argument to specify a custom Sampler object that at each time yields the next index/key to fetch.
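A sketch confirming which sampler DataLoader builds from the shuffle argument, assuming a toy list of eight integers, followed by a user-supplied sampler (shuffle must then stay False).

```python
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler

data = list(range(8))

# No sampler given: shuffle decides which sampler DataLoader constructs
print(type(DataLoader(data).sampler).__name__)                # SequentialSampler
print(type(DataLoader(data, shuffle=True).sampler).__name__)  # RandomSampler

# A user-supplied sampler: 4 draws with replacement, default batch_size=1
custom = DataLoader(data, sampler=RandomSampler(data, replacement=True, num_samples=4))
print(len(custom))  # 4, i.e. four batches of size 1
```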

A custom Sampler that yields a list of batch indices at a time can be passed as the batch_sampler argument. Automatic batching can also be enabled via batch_size and drop_last arguments. See the next section for more details on this.
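A sketch showing that automatic batching via batch_size and drop_last amounts to a BatchSampler built on top of the sampler, assuming a toy list of seven integers: the automatic and the explicit version produce the same batches.

```python
from torch.utils.data import BatchSampler, DataLoader, SequentialSampler

data = list(range(7))

# Automatic batching: DataLoader wraps its sampler in a BatchSampler internally
auto = DataLoader(data, batch_size=3, drop_last=False)
print(type(auto.batch_sampler).__name__)  # BatchSampler
auto_batches = [b.tolist() for b in auto]

# The same batches via an explicitly constructed batch_sampler
manual = DataLoader(data, batch_sampler=BatchSampler(SequentialSampler(data), 3, False))
manual_batches = [b.tolist() for b in manual]

print(auto_batches)    # [[0, 1, 2], [3, 4, 5], [6]]
print(manual_batches)  # [[0, 1, 2], [3, 4, 5], [6]]
```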
