文章目录
Sampler采样函数基类
torch.utils.data.Sampler(data_source)
所有采样器的基类。
每个采样器子类都必须提供一个__iter__()
方法,提供一种遍历dataset元素索引的方法,以及一个返回迭代器长度的__len__()
方法。
pytorch中提供的采样方法主要有SequentialSampler, RandomSampler, SubsetRandomSampler, WeightedRandomSampler
,关键是__iter__
的实现.
下面用一个简单的例子来分析各个采样函数的源码以及
import torch
from torch.utils.data.sampler import *
import numpy as np
t = np.arange(10)
SequentialSampler顺序采样
torch.utils.data.SequentialSampler(data_source)
其中__iter__
为:
iter(range(len(self.data_source)))
参数
data_source
为数据集
所以SequentialSampler
的功能是顺序逐个采样数据
for i in SequentialSampler(t):
print(i,end=',')
输出:
0,1,2,3,4,5,6,7,8,9,
RandomSampler随机采样
torch.utils.data.RandomSampler(data_source, replacement=False, num_samples=None)
其中__iter__
为:
n = len(self.data_source)
if self.replacement:
return iter(torch.randint(high=n,
size=(self.num_samples,),
dtype=torch.int64).tolist