torch.distributions.Categorical()和sample

python

torch.distributions.Categorical()

Categorical()`的参数有三个,分别为`probs`,`logits`,`validate_args

输入参数是probs

比如传入probs=[0.4, 0.3, 0.2, 0.1],或者probs=[4.0, 3.0, 2.0, 1.0],代码是直接对传入的probs进行归一化处理,对每个数据除以传入数据的累加和得到归一化后的数值,归一化的数据累加和为1。通过公式表示为:

p j = p j Σ i = 1 n p i p_j=\frac{p_j}{\varSigma _{i=1}^{n}p_i} pj=Σi=1npipj

经过处理后的所有 p i p_i pi的累加和为1,即

Σ n i = 1 p i = 1 \underset{i=1}{\overset{n}{\varSigma}}p_i=1 i=1Σnpi=1

import torch
from torch.distributions import Categorical
probs = torch.tensor([4.0,3.0,2.0,1.0])
pd = Categorical(probs=probs)
print(pd.probs)  # tensor([0.4000, 0.3000, 0.2000, 0.1000])
print(pd)  # Categorical(probs: torch.Size([4]))
print(probs)  # tensor([4., 3., 2., 1.])

probs = torch.tensor([1, 2, 3, 4])
pd = Categorical(probs=probs)
print(pd.probs)  # tensor([0.1000, 0.2000, 0.3000, 0.4000])

传入logits

self.logits = logits - logits.logsumexp(dim=-1, keepdim=True)

在这里插入图片描述

就是对logits中的每一个数据都减去其对数指数累加和,公式的最后一部分就是代码的具体实现。公式中减号后面的部分就是LogSumExp,看字面意思很形象。

import torch
from torch.distributions import Categorical
logit = torch.tensor([4.0,3.0,2.0,1.0])
pd = Categorical(logits=logit)
print(pd.logits)  # tensor([-0.4402, -1.4402, -2.4402, -3.4402])
print(pd)  # Categorical(logits: torch.Size([4]))

logit = torch.tensor([1, 2, 3, 4])
pd = Categorical(logits=logit)
pd.logits
print(pd.logits)  # tensor([-3.4402, -2.4402, -1.4402, -0.4402])


logit = torch.tensor([0.4, 0.3, 0.2, 0.1])
pd = Categorical(logits=logit)
pd.logits
print(pd.logits)  # tensor([-1.2425, -1.3425, -1.4425, -1.5425])

对上面公式进行验证 logits=[4, 3, 2, 1]

import math
logit = 4 - math.log(math.exp(4) + math.exp(3) + math.exp(2) + math.exp(1))
print(logits)  # -0.4401896985611957

.sample

DataFrame.sample(n=None, frac=None, replace=False, weights=None, random_state=None, axis=None)

n:这是一个可选参数, 由整数值组成, 并定义生成的随机行数。

frac:它也是一个可选参数, 由浮点值组成, 并返回浮点值*数据帧值的长度。不能与参数n一起使用。

replace:由布尔值组成。如果为true, 则返回带有替换的样本。替换的默认值为false。

权重:它也是一个可选参数, 由类似于str或ndarray的参数组成。默认值”无”将导致相等的概率加权。如果正在通过系列赛;它将与索引上的目标对象对齐。在采样对象中找不到的权重索引值将被忽略, 而在采样对象中没有权重的索值将被分配零权重。如果在轴= 0时正在传递DataFrame, 则返回0。它将接受列的名称。如果权重是系列;然后, 权重必须与被采样轴的长度相同。如果权重不等于1;它将被标准化为1的总和。权重列中的缺失值被视为零。权重栏中不允许无穷大。

random_state:它也是一个可选参数, 由整数或numpy.random.RandomState组成。如果值为int, 则为随机数生成器或numpy RandomState对象设置种子。

axis:它也是由整数或字符串值组成的可选参数。 0或”行”和1或”列”。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值