pytorch WeightRandomSampler要提供两个参数

朋友torch,单独提供了一个sampler模块,用来对数据进行采样,常用的有随机采样器randomsampler,当shuffle的参数为true,系统自动调用这个采样器,实现打乱数据。默认的是sequential sampler,他会按照顺序一个一个进行采样,还有一个WeightRandomSampler,他会根据每个样本的权重选取数据,在样本比例不均衡问题中,可以用它进行重采样。

WeightRandomSampler要提供两个参数,每个样本的权重weights,共选取的样本总数num_samples,以及一个可选参数replacement,。权重越大的样本被选中的概率越大,待选取的 样本数一般小于全部的样本总数。replacement用于指定是否可以重复选取某一个样本,默认为true,即允许一个epoch中重复采样某一个数据。

replacement为true,会覆盖dataset的实际大小,即一个epoch返回的图片总数取决于sampler.num_samples

 

#在数据处理中,
from  torch.utils.data.sampler import   WeightedRandomSampler
#狗的图片被取出的概率是猫的两倍
#两类图片被取出的概率与weights的绝对大小无关,至于比值有关
wights=[2 if label==1 else 1 for data, label in datasets]
wights=[2,2,1,1,1,1,2,2]
sampler=WeightedRandomSampler(wights,num_samples=9,replacement=True)

 

  • 2
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值