torch.multinomial(input, num_samples,replacement=False, out=None)
返回一个张量,每行包含从input相应行中定义的多项分布中抽取的num_samples个样本。
注意:输入input每行的值不需要总和为1 (这里我们用来做权重),但是必须非负且总和不能为0。
当抽取样本时,依次从左到右排列(第一个样本对应第一列)。
如果输入input是一个向量,输出out也是一个相同长度num_samples的向量。如果输入input是有 m行的矩阵,输出out是形如m×n的矩阵。
如果参数replacement 为 True, 则样本抽取可以重复。否则,一个样本在每行不能被重复抽取。
参数num_samples必须小于input长度(即,input的列数,如果是input是一个矩阵)。
参数:
- input (Tensor) – 包含概率值的张量
- num_samples (int) – 抽取的样本数
- replacement (bool, optional) – 布尔值,决定是否能重复抽取
- out (Tensor, optional) – 结果张量
例子:
>>> weights = torch.Tensor([0, 10, 3, 0]) # create a Tensor of weights
>>> torch.multinomial(weights, 4)
1
2
0
0
[torch.LongTensor of size 4]
>>> torch.multinomial(weights, 4, replacement=True)
1
2
1
2
[torch.LongTensor of size 4]