【pytorch学习】torch.multinomial

温馨提示:为了大家能很好的理解这个**多项式分布采用**,这里建议先看下面的这段内容

至于什么是多项式分布,这里不再赘述,不懂的同学可以去这里学习

多项式分布采样实现逻辑

以下这段内容是来自这里,这里目的是为了学习,如有侵权,联系我删除。

思路:
将每个概率值对应到[0,1]区间内的各个子区间(概率值大小体现在子区间的长度上),每次采样时,按照均匀分布随机生成一个[0,1]区间内的值,其落到哪个区间,则该区间概率值对应的元素即为被采样的元素;


算法:
1、先对概率值从大到小排列(不是必要过程,是便于加速的技巧,这样每次查找时优先检测随机数是否落在大概率的区间内,减少比较次数);
2、生成一个[0,1)区间内的随机数x (注意,Rand().nextDouble()得到的是[0,1)区间内的数,而wikipedia给出的算法中要求生成的是(0,1)区间的数);
3、将x与概率值列表中的各值pi逐个比较,并累加已比较过的前i-1个概率值的累加和sum:
若x落在[sum, sum+pi)区间内,则pi对应的元素被采样并返回 (注意区间的开闭应该参考步骤2中的情况);
否则,将pi累加入sum,继续将x与p(i+1)比较;

torch.multinomial

Parameters
	input (Tensor) – the input tensor containing probabilities
	num_samples (int) – number of samples to draw
	replacement (bool, optional) – whether to draw with replacement or not

参数说明

input : 权重(概率矩阵),也就是取每个值的概率。这里可以是一维度的,也可以是一个矩阵。操作都是按照行的概率分布进行的采样的,前者输出的是一个向量(shape为num_samples ),后者输出的是一个矩阵(shape为rows_matrix X num_samples)
num_samples : 采样的次数
**replacement (bool, optional) :**默认值值是False(即不放回采样)

注意事项

如果replacement =False,则num_samples必须小于input中非零元素的数目(如果是矩阵,则必须小于输入每一行中非零元素的最小数目)
这里给出官网的记录

>>> weights = torch.tensor([0, 10, 3, 0], dtype=torch.float) # create a tensor of weights
>>> torch.multinomial(weights, 2)
tensor([1, 2])
>>> torch.multinomial(weights, 4) # ERROR!
RuntimeError: invalid argument 2: invalid multinomial distribution (with replacement=False,
not enough non-negative category to sample) at ../aten/src/TH/generic/THTensorRandom.cpp:320
>>> torch.multinomial(weights, 4, replacement=True)
tensor([ 2,  1,  1,  1])
  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值