PyTorch学习笔记:data.WeightedRandomSampler——数据权重概率采样
torch.utils.data.WeightedRandomSampler(weights, num_samples, replacement=True, generator=None)
功能:按给定的权重(概率) [ p 0 , p 1 , … , p n − 1 ] [p_0,p_1,\dots,p_{n-1}] [p0,p1,…,pn−1]对样本索引 [ 0 , 1 , … , n − 1 ] [0,1,\dots,n-1] [0,1,…,n−1]采样
输入:
weights
:采样权重,权重之和不要求为1,该权重需要与每个样本对应起来,即权重数量等于样本数量num_samples
:所采样本的数量,可以小于weights
的数量replacement
:采样策略,如果为True
,则代表使用替换采样策略,即可重复对一个样本进行采样;如果为False
,则表示不用替换采样策略,即一个样本最多只能被采一次generator
:采样过程中的生成器
代码案例
一般用法
from torch.utils.data import WeightedRandomSampler
sampler = WeightedRandomSampler([0.1, 0.6, 1.2, 2.9, 0.8, 0.4, 0.8, 1.0, 0.9], 8)
print([i for i in sampler])
输出
这里采样得到的都是样本的索引
[5, 4, 6, 7, 0, 4, 4, 6]
replacement
设为True
与False
的区别
from torch.utils.data import WeightedRandomSampler
sampler_t = WeightedRandomSampler([0.1, 0.6, 1.2, 2.9, 0.8, 0.4, 0.8, 1.0, 0.9], 8, replacement=True)
sampler_f = WeightedRandomSampler([0.1, 0.6, 1.2, 2.9, 0.8, 0.4, 0.8, 1.0, 0.9], 8, replacement=False)
print('sampler_t:', [i for i in sampler_t])
print('sampler_f:', [i for i in sampler_f])
输出
# replacement设为True时,会对同一样本多次采样
sampler_t: [6, 1, 6, 6, 3, 3, 8, 4]
# 否则每个样本只采样一次
sampler_f: [7, 0, 2, 4, 1, 3, 8, 5]
官方文档
torch.utils.data.WeightedRandomSampler:https://pytorch.org/docs/stable/data.html?highlight=sampler#torch.utils.data.WeightedRandomSampler
初步完稿于:2022年2月22日