pytorch的torch.distributions中可以定义正态分布
如下:
import torch
from torch.distributions import Normal
mean=torch.Tensor([0,2])
normal=Normal(mean,1)
sample()
sample()就是直接在定义的正太分布(均值为mean,标准差std是1)上采样:
c=normal.sample()
print("c:",c)
输出:
c: tensor([-1.3362, 3.1730])
rsample()
rsample()不是在定义的正太分布上采样,而是先对标准正太分布 N ( 0 , 1 ) N(0,1) N(0,1)进行采样,然后输出: m e a n + s t d × 采 样 值 mean+std\times采样值 mean+std×采样值
a=normal.rsample()
输出:
a: tensor([ 0.0530, 2.8396])
log_prob(value)
log_prob(value)是计算value在定义的正态分布(mean,1)中对应的概率的对数,正太分布概率密度函数是
f
(
x
)
=
1
2
π
σ
e
−
(
x
−
μ
)
2
2
σ
2
f(x)=\frac{1}{\sqrt{2\pi}\sigma}e^{-\frac{(x-\mu)^2}{2\sigma^2}}
f(x)=2πσ1e−2σ2(x−μ)2,对其取对数可得
l
o
g
(
f
(
x
)
)
=
−
(
x
−
μ
)
2
2
σ
2
−
l
o
g
(
σ
)
−
l
o
g
(
2
π
)
log(f(x))=-\frac{(x-\mu)^2}{2\sigma^2}-log(\sigma)-log(\sqrt{2\pi})
log(f(x))=−2σ2(x−μ)2−log(σ)−log(2π)
这里我们通过对数概率还原其对应的真实概率:
print("c log_prob:",normal.log_prob(c).exp())
输出:
c log_prob: tensor([ 0.1634, 0.2005])