torch.distributions.categorical(probs)
class torch.distributions.categorical(probs)其作用是创建以参数probs为标准的类别分布,样本是来自“0,...,K-1”的整数,K是probs参数的长度。也就是说,按照probs的概率,在相应的位置进行采样,采样返回的是该位置的整数索引。如果probs是长度为K的一维列表,则每个元素是对该索引处的类进行采样的相对概率。如果probs是二维的,它被视为一批概率向量例如:probs = torch.FloatTensor([0.9,0.2])
原创
2020-05-23 13:27:47 ·
14428 阅读 ·
6 评论