的采样方式_pytorch源码阅读(三)Sampler类与4种采样方式

本文深入探讨PyTorch的Sampler类,包括Sequential Sampler、Random Sampler、Subset Random Sampler和Weighted Random Sampler的原理与用法。通过实例解析了不同采样方式在数据读取和训练过程中的影响,强调了数据采样策略对深度学习模型优化的重要性。
摘要由CSDN通过智能技术生成

0080b84a24cbf2ea19865afa2d86f9f3.png

由于我们不能将大量数据一次性放入网络中进行训练,所以需要分批进行数据读取。这一过程涉及到如何从数据集中读取数据的问题,pytorch提供了Sampler基类【1】与多个子类实现不同方式的数据采样。子类包含:

  • Sequential Sampler(顺序采样)
  • Random Sampler(随机采样)
  • Subset Random Sampler(子集随机采样)
  • Weighted Random Sampler(加权随机采样)等等。

1、基类Sampler

class Sampler(object):
    r"""Base class for all Samplers.
    """
    def __init__(self, data_source):
        pass
    def __iter__(self):
        raise NotImplementedError

对于所有的采样器来说,都需要继承Sampler类,必须实现的方法为__iter__(),也就是定义迭代器行为,返回可迭代对象。除此之外,Sampler类并没有定义任何其它的方法。

2、顺序采样Sequential Sampler

class SequentialSampler(Sampler):
    r"""Samples elements sequentially, always in the same order.
    Arguments:
        data_source (Dataset): dataset to sample from
    """
    def __init__(self, data_source):
        self.data_source = data_source
    def __iter__(self):
        return iter(range(len(self.data_source)))
    def __len__(self):
        return len(self.data_source)

顺序采样类并没有定义过多的方法,其中初始化方法仅仅需要一个Dataset类对象作为参数。对于__len__()只负责返回数据源包含的数据个数;__iter__()方法负责返回一个可迭代对象,这个可迭代对象是由range产生的顺序数值序列,也就是说迭代是按照顺序进行的。前述几种方法都只需要self.data_source实现了__len__()方法,因为这几种方法都仅仅使用了len(self.data_source)函数。所以下面采用同样实现了__len__()的list类型来代替Dataset类型做测试:

# 定义数据和对应的采样器
data = list([17, 22, 3, 41, 8])
seq_sampler &
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值