pytorch中数据采样方法Sampler源码解析

本文深入解析了PyTorch中的数据采样方法,包括SequentialSampler、RandomSampler、SubsetRandomSampler和WeightedRandomSampler。详细介绍了各个采样的工作原理和参数含义,并通过实例展示了它们的不同行为,特别是在有放回和无放回采样中的区别。对于WeightedRandomSampler,还讨论了权重对采样概率的影响。
摘要由CSDN通过智能技术生成

Sampler采样函数基类

torch.utils.data.Sampler(data_source)

所有采样器的基类。

每个采样器子类都必须提供一个__iter__()方法,提供一种遍历dataset元素索引的方法,以及一个返回迭代器长度的__len__()方法。

pytorch中提供的采样方法主要有SequentialSampler, RandomSampler, SubsetRandomSampler, WeightedRandomSampler,关键是__iter__的实现.

下面用一个简单的例子来分析各个采样函数的源码以及

import torch
from torch.utils.data.sampler import *
import numpy as np

t = np.arange(10)

SequentialSampler顺序采样

torch.utils.data.SequentialSampler(data_source)

其中__iter__为:

iter(range(len(self.data_source)))

参数

  • data_source为数据集

所以SequentialSampler的功能是顺序逐个采样数据

for i in SequentialSampler(t):
    print(i,end=',')

输出:

0,1,2,3,4,5,6,7,8,9,

RandomSampler随机采样

torch.utils.data.RandomSampler(data_source, replacement=False, num_samples=None)

其中__iter__为:

n = len(self.data_source)
if self.replacement:
    return iter(torch.randint(high=n, 
                              size=(self.num_samples,),
                              dtype=torch.int64).tolist
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值