PyTorch Source Code: the Sampler Class and Sampling Strategies

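To read data from a dataset, PyTorch provides the Sampler base class and several subclasses, each implementing a different sampling strategy. A sampler yields dataset indices, not the data itself.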

1. The Sampler base class

class Sampler(object):
    r"""Base class for all Samplers.
    """
    def __init__(self, data_source):
        pass
    def __iter__(self):
        raise NotImplementedError

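Every sampler must inherit from Sampler. The one method a subclass has to implement is __iter__(), which returns an iterator over dataset indices.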
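
As a minimal sketch of that contract, the snippet below defines a hypothetical ReverseSampler that walks the dataset from the last index to the first. The Sampler base class is re-declared here as a stand-in that mirrors the excerpt above, so the snippet runs without PyTorch; with PyTorch installed you would import it from torch.utils.data instead. ReverseSampler is an illustrative name, not part of PyTorch.

```python
class Sampler(object):
    """Stand-in for the base class shown above (assumption: mirrors the excerpt)."""

    def __init__(self, data_source):
        pass

    def __iter__(self):
        raise NotImplementedError


class ReverseSampler(Sampler):
    """Hypothetical sampler: yields indices from last to first."""

    def __init__(self, data_source):
        self.data_source = data_source

    def __iter__(self):
        # Return an iterator over indices, not over the data itself.
        return iter(range(len(self.data_source) - 1, -1, -1))

    def __len__(self):
        return len(self.data_source)


reversed_indices = list(ReverseSampler(["a", "b", "c"]))
print(reversed_indices)  # [2, 1, 0]
```

Only len(data_source) is consulted, so any sized container can play the role of the dataset in this sketch.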

2. Sequential sampling: SequentialSampler

class SequentialSampler(Sampler):
    r"""Samples elements sequentially, always in the same order.
    Arguments:
        data_source (Dataset): dataset to sample from
    """
    def __init__(self, data_source):
        self.data_source = data_source
    def __iter__(self):
        return iter(range(len(self.data_source)))
    def __len__(self):
        return len(self.data_source)

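
To make the behavior concrete, here is a small sketch of what iterating a SequentialSampler amounts to. It inlines the body of __iter__ instead of importing PyTorch, and data_source is a made-up list used only for illustration.

```python
data_source = ["cat", "dog", "bird", "fish"]  # hypothetical dataset

# Same expression as SequentialSampler.__iter__ above.
indices = list(iter(range(len(data_source))))
print(indices)  # [0, 1, 2, 3]

# The sampler yields indices; looking up the samples is a separate step.
samples = [data_source[i] for i in indices]
print(samples)  # ['cat', 'dog', 'bird', 'fish']
```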

3. Random sampling: RandomSampler

class RandomSampler(Sampler):
    r"""Samples elements randomly. If without replacement, then sample from a shuffled dataset.
    If with replacement, then user can specify :attr:`num_samples` to draw.
    Arguments:
        data_source (Dataset): dataset to sample from
        replacement (bool): samples are drawn with replacement if ``True``, default=``False``
        num_samples (int): number of samples to draw, default=`len(dataset)`. This argument
            is supposed to be specified only when `replacement` is ``True``.
    """
    def __init__(self, data_source, replacement=False, num_samples=None):
        self.data_source = data_source
        # Whether to sample with replacement (indices may repeat)
        self.replacement = replacement
        self._num_samples = num_samples
        # (type checks omitted)
    @property
    def num_samples(self):
        # dataset size might change at runtime
        # Fall back to the dataset length when num_samples was not given
        if self._num_samples is None:
            return len(self.data_source)
        return self._num_samples
    def __iter__(self):
        n = len(self.data_source)
        if self.replacement:
            # With replacement: the drawn indices may repeat
            return iter(torch.randint(high=n, size=(self.num_samples,), dtype=torch.int64).tolist())
        # Without replacement: a random permutation, so no index repeats
        return iter(torch.randperm(n).tolist())
    def __len__(self):
        # Number of indices drawn per pass (the dataset length unless num_samples was given)
        return self.num_samples

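
The difference between the two branches of __iter__ can be sketched without PyTorch. The snippet below swaps in Python's random module for torch.randint and torch.randperm, so it illustrates the semantics only; the helper name sample_indices and its seed parameter are hypothetical.

```python
import random


def sample_indices(n, replacement=False, num_samples=None, seed=None):
    """Illustrative stand-in for RandomSampler.__iter__ (uses random, not torch)."""
    rng = random.Random(seed)
    if num_samples is None:
        num_samples = n  # same fallback as the num_samples property
    if replacement:
        # With replacement: each draw is independent, so indices may repeat.
        return [rng.randrange(n) for _ in range(num_samples)]
    # Without replacement: a random permutation of 0..n-1, no repeats.
    return rng.sample(range(n), n)


without = sample_indices(5, seed=0)
with_repl = sample_indices(5, replacement=True, num_samples=8, seed=0)

print(sorted(without))  # [0, 1, 2, 3, 4]: every index appears exactly once
print(len(with_repl))   # 8: more draws than the dataset has elements
```

Drawing 8 indices from 5 elements is only possible with replacement, which is why the docstring ties num_samples to replacement=True.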

4. Subset random sampling: SubsetRandomSampler

class SubsetRandomSampler(Sampler):
    r"""Samples elements randomly from a given list of indices, without replacement.
    Arguments:
        indices (sequence): a sequence of indices
    """
    def __init__(self, indices):
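        # A subset of the dataset's indices, e.g. one part of a train/test split
        self.indices = indices
    def __iter__(self):
        # Generator expression (not a tuple): yields the given indices in random order, each exactly once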
        return (self.indices[i] for i in torch.randperm(len(self.indices)))
    def __len__(self):
        return len(self.indices)

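SubsetRandomSampler is typically used to split a dataset into training, validation, and test sets: each split gets its own list of indices and its own sampler.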
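
A sketch of such a split in plain Python: the indices are partitioned once, and each partition is then shuffled independently on every pass, which is what SubsetRandomSampler.__iter__ does with torch.randperm. The 80/20 ratio, the dataset size, and the helper name shuffled are assumptions for illustration.

```python
import random

dataset_size = 10  # hypothetical dataset length
indices = list(range(dataset_size))

# Partition the indices once: first 80% for training, the rest for validation.
split = int(0.8 * dataset_size)
train_indices, val_indices = indices[:split], indices[split:]


def shuffled(subset, rng=random):
    """Stand-in for SubsetRandomSampler.__iter__: a permutation of the subset."""
    return (subset[i] for i in rng.sample(range(len(subset)), len(subset)))


train_order = list(shuffled(train_indices))
val_order = list(shuffled(val_indices))

print(sorted(train_order))  # [0, 1, 2, 3, 4, 5, 6, 7]
print(sorted(val_order))    # [8, 9]
```

With PyTorch, each index list would be wrapped in its own SubsetRandomSampler and handed to a DataLoader through its sampler argument, so the two loaders never see each other's samples.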

5. Batch sampling: BatchSampler

class BatchSampler(Sampler):
    r"""Wraps another sampler to yield a mini-batch of indices."""
    def __init__(self, sampler, batch_size, drop_last):  # ... type checks omitted
        # The underlying sampler that supplies the indices
        self.sampler = sampler
        self.batch_size = batch_size
        # Whether to drop the final batch when it has fewer than batch_size indices
        self.drop_last = drop_last
    def __iter__(self):
        batch = []
        for idx in self.sampler:
            batch.append(idx)
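            # The batch is complete once it holds batch_size indices
            if len(batch) == self.batch_size:
                yield batch
                batch = []
        # After the loop, yield the leftover partial batch unless drop_last is set
        if len(batch) > 0 and not self.drop_last:
            yield batch
    def __len__(self):
        # Number of batches: floor division when the partial batch is dropped,
        # ceiling division otherwise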
        if self.drop_last:
            return len(self.sampler) // self.batch_size
        else:
            return (len(self.sampler) + self.batch_size - 1) // self.batch_size

When drop_last is True, a final batch that contains fewer than batch_size indices is discarded.
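
The batching loop and the two length formulas can be checked with a plain-Python sketch. batch_indices and num_batches are hypothetical helper names that mirror __iter__ and __len__ above; any iterable of indices stands in for the wrapped sampler.

```python
def batch_indices(sampler, batch_size, drop_last):
    """Mirror of BatchSampler.__iter__: group indices into lists of batch_size."""
    batch = []
    for idx in sampler:
        batch.append(idx)
        if len(batch) == batch_size:
            yield batch
            batch = []
    # Emit the leftover partial batch only when drop_last is False.
    if len(batch) > 0 and not drop_last:
        yield batch


def num_batches(n, batch_size, drop_last):
    """Mirror of BatchSampler.__len__: floor or ceiling division."""
    if drop_last:
        return n // batch_size
    return (n + batch_size - 1) // batch_size


kept = list(batch_indices(range(10), batch_size=3, drop_last=False))
dropped = list(batch_indices(range(10), batch_size=3, drop_last=True))

print(kept)     # [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
print(dropped)  # [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
print(num_batches(10, 3, False), num_batches(10, 3, True))  # 4 3
```

The counts returned by num_batches match the number of lists actually produced, which is the invariant __len__ is meant to preserve.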
