torch.distributions.Categorical 分类分布

1.功能

根据概率分布来产生sample,产生的sample是输入tensor的index。

2.使用方法

可以通过传入一个K维向量probs或者logits来描述每个类别的概率分布,然后利用概率分布对类别采样,类别值在[0,K-1]之间,即向量元素的index。其中probs表示每个类别的概率,logits可以简单解释为将probs转换成对数概率。使用log_prob方法可以计算类别的对数概率。

3.代码案例

from torch.distributions import Categorical
import torch 
log_prob = torch.tensor([-1.6094, -0.2231]) #给定对数概率
pdparams = log_prob 
pd = Categorical(logits=pdparams)

sample = pd.sample() #0 or 1
pd.log_prob(sample ) #计算类别的对数概率
print(sample )
print(pd.log_prob(sample ))

4 验证

from torch.distributions import Categorical
import torch
import numpy as np 
log_prob = torch.tensor([-1.6094, -0.2231]) #给定对数概率
pdparams = log_prob 
pd = Categorical(logits=pdparams)
 
sample = pd.sample() #0 or 1
pd.log_prob(sample ) #计算类别的对数概率
print(sample )
print(pd.log_prob(sample ))
print(np.exp(-0.2231) + np.exp(-1.6094))

运行结果,如下:

为了避免log_prob本身指数之和概率为1,修改代码如下:

from torch.distributions import Categorical
import torch
import numpy as np 
log_prob = torch.tensor([-1.6094, -0.31]) #给定对数概率
pdparams = log_prob 
pd = Categorical(logits=pdparams)
 
sample = pd.sample() #0 or 1
pd.log_prob(sample ) #计算类别的对数概率
print(sample )
print(pd.log_prob(sample ))
print(np.exp(-0.2411) + np.exp(-1.5404))

运行结果,如下:

原本的对数概率会被规范化。

评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值