一、介绍
Categorical函数来自包 torch.distributions,官方定义的接口如下:
class torch.distributions.Categorical(probs)
作用是创建以参数probs为标准的类别分布,样本是来自 “0 … K-1” 的整数,其中 K 是probs参数的长度。也就是说,按照传入的probs中给定的概率,在相应的位置处进行取样,取样返回的是该位置的整数索引。
如果 probs
是长度为 K
的一维列表,则每个元素是对该索引处的类进行抽样的相对概率。
如果 probs
是二维的,它被视为一批概率向量。
二、使用示例
probs = torch.FloatTensor([[0.05,0.1,0.85],[0.05,0.05,0.9]])
dist = Categorical(probs)
print(dist)
# Categorical(probs: torch.Size([2, 3]))
index = dist.sample()
print(index.numpy())
# [2 2]