理解出错之处望不吝指正。
今天听师兄给我们讲了VAE,觉得颇有收获,分享一下,希望大家批评指正。
生成模型
生成模型的目的是从一系列样本
x
=
{
x
1
,
x
2
,
.
.
.
,
x
m
}
x=\{x_1,x_2,...,x_m\}
x={x1,x2,...,xm}中学习
x
x
x的分布
p
(
x
)
p(x)
p(x),我们可以仿照EM算法,通过隐变量
z
z
z和生成函数
g
(
)
g()
g()来得到
x
^
=
g
(
z
)
\hat x=g(z)
x^=g(z),并尽可能的让
x
^
\hat {x}
x^接近
x
x
x。
上述方法有一个弊端,我们首先依据全概率公式将
p
(
x
)
p(x)
p(x)写成如下形式:
p
(
x
)
=
∑
p
(
x
∣
z
)
p
(
z
)
p(x)=\sum p(x|z)p(z)
p(x)=∑p(x∣z)p(z)
我们易获取到
p
(
z
)
p(z)
p(z),但是在
z
z
z确定的情况下,无法得知
p
(
x
∣
z
)
p(x|z)
p(x∣z),这意味着我们无法将 “利用
g
(
)
g()
g()函数生成的
x
^
i
\hat x_i
x^i” 与 “真实的
x
i
x_i
xi” 进行对应,如下图所示:
那么怎么解决这个问题呢,接着往下看~
VAE模型
在VAE模型中,我们解决了上述问题。我们可以从两个角度理解VAE模型,首先是角度1。
理解角度1
为了解决上述问题,将学习
p
(
x
)
p(x)
p(x)改为学习
p
(
z
∣
x
)
p(z|x)
p(z∣x)即可!
我们令
p
(
z
∣
x
)
∼
N
(
μ
,
σ
2
)
p(z|x)\sim N(\mu,\sigma^2)
p(z∣x)∼N(μ,σ2)(ps:为什么这个先验分布是正态分布呢?因为若是其他分布,后面计算KL散度的时候会导致分母为0)。则VAE的编码解码过程如下图所示:
更具体的,我们假设
p
(
z
∣
x
)
∼
N
(
0
,
1
)
p(z|x)\sim N(0,1)
p(z∣x)∼N(0,1),即:我们在编码过程中期望学到
μ
i
=
0
,
σ
i
=
1
\mu_i=0,\sigma_i=1
μi=0,σi=1,则VAE的训练过程中会产生以下两种情况(类似于对抗):
1.若
σ
i
2
→
1
\sigma_i^2\rightarrow1
σi2→1,此时加在
x
i
x_i
xi上的噪声大,会导致已有的解码能力效果变差,通过反向传播会使得
σ
i
2
→
0
\sigma_i^2\rightarrow0
σi2→0;
2.若
σ
i
2
→
0
\sigma_i^2\rightarrow0
σi2→0,此时加在
x
i
x_i
xi上的噪声小,会导致已有的解码能力效果变好,通过反向传播会使得
σ
i
2
→
1
\sigma_i^2\rightarrow1
σi2→1;
VAE模型的损失函数如下:
L
v
a
e
p
(
z
∣
x
)
=
l
o
s
s
(
x
,
x
^
)
+
K
L
[
N
(
μ
,
σ
2
)
∣
∣
N
(
0
,
1
)
]
L_{vae}^{p(z|x)}=loss(x,\hat x)+KL[N(\mu,\sigma^2)||N(0,1)]
Lvaep(z∣x)=loss(x,x^)+KL[N(μ,σ2)∣∣N(0,1)]
损失函数中第一部分
l
o
s
s
(
x
,
x
^
)
loss(x,\hat x)
loss(x,x^)代表真实数据
x
x
x与生成数据
x
^
\hat x
x^之间的误差,可以使用简单的logistics损失或者MSE损失。
损失函数中第二部分是一个KL散度,用于衡量编码过程中得到的分布是否接近我们设置的先验分布
N
(
0
,
1
)
N(0,1)
N(0,1),下面对这部分进行详细的剖析。
K L [ N ( μ , σ 2 ) ∣ ∣ N ( 0 , 1 ) ] KL[N(\mu,\sigma^2)||N(0,1)] KL[N(μ,σ2)∣∣N(0,1)]
= ∫ 1 2 π σ e x p { − ( x − μ ) 2 2 σ 2 } ⋅ log 1 2 π σ e x p { − ( x − μ ) 2 2 σ 2 } 1 2 π e x p { − x 2 2 } d x =\int \frac{1}{\sqrt{2\pi}\sigma}exp\{\frac{-(x-\mu)^2}{2\sigma^2}\}\centerdot\log \frac{\frac{1}{\sqrt{2\pi}\sigma}exp\{\frac{-(x-\mu)^2}{2\sigma^2}\}}{\frac{1}{\sqrt{2\pi}}exp\{\frac{-x^2}{2}\}}dx =∫2πσ1exp{2σ2−(x−μ)2}⋅log2π1exp{2−x2}2πσ1exp{2σ2−(x−μ)2}dx
= 1 2 ∫ 1 2 π σ e x p { − ( x − μ ) 2 2 σ 2 } ⋅ [ − log σ 2 + x 2 − ( x − μ ) 2 σ 2 ] d x =\frac{1}{2}\int \frac{1}{\sqrt{2\pi}\sigma}exp\{\frac{-(x-\mu)^2}{2\sigma^2}\}\centerdot[-\log \sigma^2+x^2-\frac{(x-\mu)^2}{\sigma^2}]dx =21∫2πσ1exp{2σ2−(x−μ)2}⋅[−logσ2+x2−σ2(x−μ)2]dx
= 1 2 ( − log σ 2 + μ 2 + σ 2 − 1 ) =\frac{1}{2}(-\log \sigma^2+\mu^2+\sigma^2-1) =21(−logσ2+μ2+σ2−1)
=
1
2
μ
2
+
1
2
(
σ
2
−
log
σ
2
−
1
)
=\frac{1}{2}\mu^2+\frac{1}{2}(\sigma^2-\log \sigma^2-1)
=21μ2+21(σ2−logσ2−1)
理解角度2
在理解角度1中,我们由于不能计算 p ( x ∣ z ) p(x|z) p(x∣z),所以将学习 p ( x ) p(x) p(x)改为了学习 p ( z ∣ x ) p(z|x) p(z∣x),但是我们忽略了 p ( x ) p(x) p(x)不仅可以使用全概率公式分解为 p ( x ) = ∑ p ( x ∣ z ) p ( z ) p(x)=\sum p(x|z)p(z) p(x)=∑p(x∣z)p(z),还可将 p ( x ) p(x) p(x)写为 p ( x ) = ∫ p ( x , z ) d z p(x)=\int p(x,z)dz p(x)=∫p(x,z)dz,在这种情况下,我们假设先验分布为 q ( x , z ) q(x,z) q(x,z),则我们的学习目标变为:令 p ( x , z ) p(x,z) p(x,z)无限趋近 q ( x , z ) q(x,z) q(x,z),如下所示:
K L [ p ( x , z ) ∣ ∣ q ( x , z ) ] KL[p(x,z)||q(x,z)] KL[p(x,z)∣∣q(x,z)]
= ∫ ∫ p ( x , z ) log p ( x , z ) q ( x , z ) d z d x =\int\int p(x,z)\log \frac{p(x,z)}{q(x,z)} dz dx =∫∫p(x,z)logq(x,z)p(x,z)dzdx
将 p ( x , z ) = p ^ ( x ) p ( z ∣ x ) p(x,z)=\hat p(x)p(z|x) p(x,z)=p^(x)p(z∣x)带入上式,其中 p ^ ( x ) \hat p(x) p^(x)代表利用已有的 x x x值通过估计得到的分布,可得:
= ∫ p ^ ( x ) [ ∫ p ( z ∣ x ) log p ^ ( x ) p ( z ∣ x ) q ( x , z ) d z ] d x =\int \hat p(x)[ \int p(z|x) \log \frac{\hat p(x)p(z|x)}{q(x,z)}dz] dx =∫p^(x)[∫p(z∣x)logq(x,z)p^(x)p(z∣x)dz]dx
= E x ∼ p ^ ( x ) [ ∫ p ( z ∣ x ) log p ^ ( x ) p ( z ∣ x ) q ( x , z ) d z ] =E_{x\sim\hat p(x)}[\int p(z|x) \log \frac {\hat p(x)p(z|x)}{q(x,z)}dz] =Ex∼p^(x)[∫p(z∣x)logq(x,z)p^(x)p(z∣x)dz]
将 q ( x , z ) = q ( z ) q ( x ∣ z ) q(x,z)=q(z)q(x|z) q(x,z)=q(z)q(x∣z)和 p ( z ∣ x ) log p ( z ∣ x ) q ( z ) = K L [ p ( z ∣ x ) ∣ ∣ q ( z ) ] p(z|x)\log \frac {p(z|x)}{q(z)}=KL[p(z|x)||q(z)] p(z∣x)logq(z)p(z∣x)=KL[p(z∣x)∣∣q(z)]带入上式,可得:
=
E
x
∼
p
^
(
x
)
{
E
z
∼
p
(
z
∣
x
)
[
−
log
q
(
x
∣
z
)
]
+
K
L
[
p
(
z
∣
x
)
∣
∣
q
(
z
)
]
}
=E_{x\sim\hat p(x)}\{E_{z\sim p(z|x)}[-\log q(x|z)] + KL[p(z|x)||q(z)]\}
=Ex∼p^(x){Ez∼p(z∣x)[−logq(x∣z)]+KL[p(z∣x)∣∣q(z)]}
我们可以发现,这样计算得到的公式中每一项可以和
L
v
a
e
p
(
z
∣
x
)
L_{vae}^{p(z|x)}
Lvaep(z∣x)中的每一项对应:
− log q ( x ∣ z ) ⟷ l o s s ( x , x ^ ) -\log q(x|z)\longleftrightarrow loss(x,\hat x) −logq(x∣z)⟷loss(x,x^)
K
L
[
p
(
z
∣
x
)
∣
∣
q
(
z
)
]
⟷
K
L
[
N
(
μ
,
σ
2
)
∣
∣
N
(
0
,
1
)
]
KL[p(z|x)||q(z)]\longleftrightarrow KL[N(\mu,\sigma^2)||N(0,1)]
KL[p(z∣x)∣∣q(z)]⟷KL[N(μ,σ2)∣∣N(0,1)]
原文中说
q
(
x
∣
z
)
q(x|z)
q(x∣z)可以为两种分布:1.伯努利分布(B);2.正态分布(N)
1.当
q
(
x
∣
z
)
∼
B
q(x|z)\sim B
q(x∣z)∼B时,容易看出这是交叉熵损失:
−
log
q
(
x
∣
z
)
=
∑
{
x
k
log
p
k
(
z
)
+
(
1
−
x
k
)
log
[
1
−
p
k
(
z
)
]
}
-\log q(x|z) = \sum \{x_k\log p_k(z)+(1-x_k)\log [1-p_k(z)]\}
−logq(x∣z)=∑{xklogpk(z)+(1−xk)log[1−pk(z)]}
2.当
q
(
x
∣
z
)
∼
N
q(x|z)\sim N
q(x∣z)∼N时,容易看出这是均方误差损失:
−
log
q
(
x
∣
z
)
=
1
2
σ
2
∣
∣
x
−
μ
k
∣
∣
2
-\log q(x|z) = \frac {1}{2\sigma_2}||x-\mu_k||^2
−logq(x∣z)=2σ21∣∣x−μk∣∣2