PyTorch学习笔记:data.WeightedRandomSampler——数据权重概率采样

PyTorch学习笔记:data.WeightedRandomSampler——数据权重概率采样

torch.utils.data.WeightedRandomSampler(weights, num_samples, replacement=True, generator=None)

功能:按给定的权重(概率) [ p 0 , p 1 , … , p n − 1 ] [p_0,p_1,\dots,p_{n-1}] [p0,p1,,pn1]样本索引 [ 0 , 1 , … , n − 1 ] [0,1,\dots,n-1] [0,1,,n1]采样

输入:

  • weights:采样权重,权重之和不要求为1,该权重需要与每个样本对应起来,即权重数量等于样本数量
  • num_samples:所采样本的数量,可以小于weights的数量
  • replacement:采样策略,如果为True,则代表使用替换采样策略,即可重复对一个样本进行采样;如果为False,则表示不用替换采样策略,即一个样本最多只能被采一次
  • generator:采样过程中的生成器

代码案例

一般用法

from torch.utils.data import WeightedRandomSampler

sampler = WeightedRandomSampler([0.1, 0.6, 1.2, 2.9, 0.8, 0.4, 0.8, 1.0, 0.9], 8)
print([i for i in sampler])

输出

这里采样得到的都是样本的索引

[5, 4, 6, 7, 0, 4, 4, 6]

replacement设为TrueFalse的区别

from torch.utils.data import WeightedRandomSampler

sampler_t = WeightedRandomSampler([0.1, 0.6, 1.2, 2.9, 0.8, 0.4, 0.8, 1.0, 0.9], 8, replacement=True)
sampler_f = WeightedRandomSampler([0.1, 0.6, 1.2, 2.9, 0.8, 0.4, 0.8, 1.0, 0.9], 8, replacement=False)
print('sampler_t:', [i for i in sampler_t])
print('sampler_f:', [i for i in sampler_f])

输出

# replacement设为True时,会对同一样本多次采样
sampler_t: [6, 1, 6, 6, 3, 3, 8, 4]
# 否则每个样本只采样一次
sampler_f: [7, 0, 2, 4, 1, 3, 8, 5]

官方文档

torch.utils.data.WeightedRandomSampler:https://pytorch.org/docs/stable/data.html?highlight=sampler#torch.utils.data.WeightedRandomSampler

初步完稿于:2022年2月22日

  • 0
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

视觉萌新、

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值