torch.multinomial() 是 PyTorch 库中的一个函数,它主要用于从给定的多项分布中抽取样本。这个函数对于进行随机抽样操作非常有用,特别是在处理概率分布和执行诸如强化学习、自然语言处理等任务时。
具体来说,torch.multinomial(input, num_samples, replacement=False, *, generator=None, out=None) 函数的参数意义如下:
input: 包含非负元素的张量,代表多项分布的概率值。每个元素都是对应类别的概率权重,通常需要对这个张量进行归一化处理,以确保所有元素之和为 1。
num_samples: 需要抽取的样本数量。这个值必须小于或等于input中的元素数量,除非设置了replacement=True,允许重复抽取。
replacement: 表示是否允许在抽取过程中替换。如果为True,则允许一个元素被多次抽取;如果为False,每个元素最多只能被抽取一次。
generator: 用于指定一个随机数生成器。
out: 用于指定输出张量。
例子:
import torch
a = torch.randn(5,4)
a = torch.exp(a)
a = a/a.sum(1,keepdim=True) # 归一化
print(torch.multinomial(a,2)
out:
tensor([[1, 3],
[1, 2],
[1, 3],
[0, 1],
[3, 2]])
含义是,一共有5个样本(batch=5),列代表了每一类的概率;
在这里,一共有5个样本,每个样本有4个类别,我们根据概率
在针对每一个样本抽取2个类别!
本次呢:第一个样本抽取了第一个和第三个类别,
第二个样本抽取了第一个和第二个类别,
以此类推!!!