1.功能
根据概率分布来产生sample,产生的sample是输入tensor的index。
2.使用方法
可以通过传入一个K维向量probs或者logits来描述每个类别的概率分布,然后利用概率分布对类别采样,类别值在[0,K-1]之间,即向量元素的index。其中probs表示每个类别的概率,logits可以简单解释为将probs转换成对数概率。使用log_prob方法可以计算类别的对数概率。
3.代码案例
from torch.distributions import Categorical
import torch
log_prob = torch.tensor([-1.6094, -0.2231]) #给定对数概率
pdparams = log_prob
pd = Categorical(logits=pdparams)
sample = pd.sample() #0 or 1
pd.log_prob(sample ) #计算类别的对数概率
print(sample )
print(pd.log_prob(sample ))
4 验证
from torch.distributions import Categorical
import torch
import numpy as np
log_prob = torch.tensor([-1.6094, -0.2231]) #给定对数概率
pdparams = log_prob
pd = Categorical(logits=pdparams)
sample = pd.sample() #0 or 1
pd.log_prob(sample ) #计算类别的对数概率
print(sample )
print(pd.log_prob(sample ))
print(np.exp(-0.2231) + np.exp(-1.6094))
运行结果,如下:
为了避免log_prob本身指数之和概率为1,修改代码如下:
from torch.distributions import Categorical
import torch
import numpy as np
log_prob = torch.tensor([-1.6094, -0.31]) #给定对数概率
pdparams = log_prob
pd = Categorical(logits=pdparams)
sample = pd.sample() #0 or 1
pd.log_prob(sample ) #计算类别的对数概率
print(sample )
print(pd.log_prob(sample ))
print(np.exp(-0.2411) + np.exp(-1.5404))
运行结果,如下:
原本的对数概率会被规范化。