提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档
前言
在学习扩散模型的过程中,发现其许多的理论都与变分自编码器(VAE)有着密切的联系,所以学习一下VAE也有利于理解和掌握扩散模型。
自编码器(Auto-encoder)
在学习变分自编码器前,有必要先学习一下自编码器。自编码器由一个编码器(encoder)和一个解码器(decoder)组成。顾名思义,编码器捕捉输入数据的重要特征,将高维输入编码(低维);而解码器就是将低维编码解码回高维空间。
上图中,z表示低维空间中的特征,
θ
\theta
θ和
ϕ
\phi
ϕ分别表示编码器和解码器的神经网络的参数。网络的训练损失函数为
L
(
x
,
x
^
)
=
∥
x
−
x
^
∥
2
L(x,\hat x) = {\left\| {x - \hat x} \right\|^2}
L(x,x^)=∥x−x^∥2
虽然自编码器中有解码器,但是实际应用中却不能作为生成模型使用,因为自编码器中的隐变量是离散编码(不具有连续性),也就是说低维空间中的许多点经过解码器网络生成的结果均为没有意义的无效输出,而变分自编码器就对此进行了改进。
参考【李宏毅2021/2022春机器学习课程】-自编码器
变分自编码器(Variational Auto Encoder)
变分自编码器与自编码器不一样的地方就在于变分自编码器中隐变量是一个概率分布。
变分自编码器中编码器输出的是隐变量分布的参数,其中隐变量
z
∼
N
(
μ
,
σ
2
)
z\thicksim N(\mu ,{\sigma ^2})
z∼N(μ,σ2),而直接采样是一个随机过程,会导致梯度无法回传,所以如上图所示,使用了一个重参数采样技巧来获取隐变量。
重参数采样:把从有参数的分布中采样转变成从无参数的分布中采样。
例如,z 服从均值为 μ \mu μ,方差为 σ 2 \sigma ^2 σ2的高斯分布。如果对其进行重参数采样,可以写成 z = μ + σ ε z = \mu + \sigma \varepsilon z=μ+σε其中 ε \varepsilon ε服从标准正态分布。
梳理VAE的过程,作为一个生成模型,我们希望 z 服从标准正态分布,能够做到从标准正态分布中任意采样的 z 可以经过解码器生成一张具有实际语义的图像。那么 隐变量与实际图像又是如何对应的呢?
从VAE的整个设计思想出发,首先从数据集中采样 x ,然后从后验分布
q
(
z
∣
x
)
q(z|x)
q(z∣x)中采样得到隐变量 z ,而 z 经过解码器,可以恢复成原始样本 x 。所以使用模型
q
ϕ
(
z
∣
x
)
{q_\phi }(z|x)
qϕ(z∣x)来拟合真实后验分布,也称概率编码器,同样,使用
p
θ
(
x
∣
z
)
{p_\theta }(x|z)
pθ(x∣z)作为概率解码器。
那么模型到底如何训练,训练的损失函数又应该如何设置呢?
其实,对于生成模型来说,如果有一堆样本
{
x
1
,
x
2
.
.
.
x
N
}
\{ {x_1},{x_2}...{x_N}\}
{x1,x2...xN},它们都服从某一种概率分布
p
(
x
)
p(x)
p(x),如果我们能够直接求到这个分布的话,就可以直接从
p
(
x
)
p(x)
p(x)中采样生成图片。虽然不能直接求出这个概率,但可以通过隐变量将其表示出来:
p
(
x
)
=
∫
z
p
(
z
)
p
(
x
∣
z
)
d
z
p(x) = \int\limits_z {p(z)p(x|z)dz}
p(x)=z∫p(z)p(x∣z)dz原则上,我们希望这个分布越大越好,所以模型的优化目标是最大化真实数据分布的似然估计,即求下式的最大值
∑
1
N
p
(
x
i
)
\sum\nolimits_1^N {p({x_i})}
∑1Np(xi)
对于任意一个样本 x 的对数分布,可以表示为
log
p
(
x
)
=
∫
z
q
ϕ
(
z
∣
x
)
log
p
(
x
)
d
z
=
∫
z
q
ϕ
(
z
∣
x
)
log
p
(
z
,
x
)
q
ϕ
(
z
∣
x
)
q
ϕ
(
z
∣
x
)
p
(
z
∣
x
)
d
z
=
∫
z
q
ϕ
(
z
∣
x
)
log
p
(
z
,
x
)
q
ϕ
(
z
∣
x
)
d
z
+
∫
z
q
ϕ
(
z
∣
x
)
log
q
ϕ
(
z
∣
x
)
p
(
z
∣
x
)
d
z
=
∫
z
q
ϕ
(
z
∣
x
)
log
p
(
z
,
x
)
q
ϕ
(
z
∣
x
)
d
z
+
D
K
L
(
q
ϕ
(
z
∣
x
)
∣
∣
p
(
z
∣
x
)
)
\begin{aligned} \log p(x) &= \int\limits_z {{q_\phi }(z|x)\log p(x)dz} \\ &= \int\limits_z {{q_\phi }(z|x)\log {{p(z,x){q_\phi }(z|x)} \over {{q_\phi }(z|x)p(z|x)}}dz} \\ &= \int\limits_z {{q_\phi }(z|x)\log {{p(z,x)} \over {{q_\phi }(z|x)}}dz} + \int\limits_z {{q_\phi }(z|x)\log {{{q_\phi }(z|x)} \over {p(z|x)}}dz} \\ &= \int\limits_z {{q_\phi }(z|x)\log {{p(z,x)} \over {{q_\phi }(z|x)}}dz} + D_{KL}({q_\phi }(z|x)||p(z|x)) \\ \end{aligned}
logp(x)=z∫qϕ(z∣x)logp(x)dz=z∫qϕ(z∣x)logqϕ(z∣x)p(z∣x)p(z,x)qϕ(z∣x)dz=z∫qϕ(z∣x)logqϕ(z∣x)p(z,x)dz+z∫qϕ(z∣x)logp(z∣x)qϕ(z∣x)dz=z∫qϕ(z∣x)logqϕ(z∣x)p(z,x)dz+DKL(qϕ(z∣x)∣∣p(z∣x))上式中第二项是KL散度,根据KL散度的非负性,可得
log
p
(
x
i
)
≥
∫
z
q
ϕ
(
z
∣
x
)
log
p
(
z
,
x
)
q
ϕ
(
z
∣
x
)
d
z
=
∫
z
q
ϕ
(
z
∣
x
)
log
p
θ
(
x
∣
z
)
p
(
z
)
q
ϕ
(
z
∣
x
)
d
z
=
∫
z
q
ϕ
(
z
∣
x
)
log
p
(
z
)
q
ϕ
(
z
∣
x
)
d
z
+
∫
z
q
ϕ
(
z
∣
x
)
log
p
θ
(
x
∣
z
)
d
z
=
−
D
K
L
(
q
ϕ
(
z
∣
x
)
∣
∣
p
(
z
)
)
+
∫
z
q
ϕ
(
z
∣
x
)
log
p
θ
(
x
∣
z
)
d
z
\begin{aligned} \log p({x_i}) &\ge \int\limits_z {{q_\phi }(z|x)\log {{p(z,x)} \over {{q_\phi }(z|x)}}} dz \cr & = \int\limits_z {{q_\phi }(z|x)\log {{{p_\theta }(x|z)p(z)} \over {{q_\phi }(z|x)}}} dz \cr & = \int\limits_z {{q_\phi }(z|x)\log {{p(z)} \over {{q_\phi }(z|x)}}} dz + \int\limits_z {{q_\phi }(z|x)\log {p_\theta }(x|z)} dz \cr & = - {D_{KL}}({q_\phi }(z|x)||p(z)) + \int\limits_z {{q_\phi }(z|x)\log {p_\theta }(x|z)} dz \end{aligned}
logp(xi)≥z∫qϕ(z∣x)logqϕ(z∣x)p(z,x)dz=z∫qϕ(z∣x)logqϕ(z∣x)pθ(x∣z)p(z)dz=z∫qϕ(z∣x)logqϕ(z∣x)p(z)dz+z∫qϕ(z∣x)logpθ(x∣z)dz=−DKL(qϕ(z∣x)∣∣p(z))+z∫qϕ(z∣x)logpθ(x∣z)dz找到了对数似然估计的下界(不等式右侧),所以最大似然估计的问题就变成寻找下界的问题,这个下界就是变分下界,只要下界足够大,那么似然估计也就更大。将上式两项分别展开,因为我们的目标是要让隐变量分布接标准高斯分布,所以第一项可以展开为
−
D
K
L
(
q
ϕ
(
z
∣
x
)
∣
∣
p
(
z
)
)
=
−
D
K
L
(
N
(
μ
,
σ
2
)
∣
∣
N
(
0
,
I
)
)
=
1
2
(
log
σ
2
−
μ
2
−
σ
2
+
1
)
\begin{aligned} & - {D_{KL}}({q_\phi }(z|x)||p(z)) \cr & = - {D_{KL}}(N(\mu ,{\sigma ^2})||N(0,I)) \cr & = {1 \over 2}(\log {\sigma ^2} - {\mu ^2} - {\sigma ^2} + 1) \cr \end{aligned}
−DKL(qϕ(z∣x)∣∣p(z))=−DKL(N(μ,σ2)∣∣N(0,I))=21(logσ2−μ2−σ2+1)该项为约束项,对应的损失函数为
L
1
=
∑
μ
i
2
+
σ
i
2
−
log
σ
i
2
−
1
{L_1} = \sum {{\mu _i}^2 + {\sigma _i}^2 - \log {\sigma _i}^2 - 1}
L1=∑μi2+σi2−logσi2−1第二项可以展开为
∫
z
q
ϕ
(
z
∣
x
)
log
p
θ
(
x
∣
z
)
d
z
=
E
q
ϕ
(
z
∣
x
)
[
log
p
θ
(
x
∣
z
)
]
\int\limits_z {{q_\phi }(z|x)\log {p_\theta }(x|z)} dz = {E_{{q_\phi }(z|x)}}[\log {p_\theta }(x|z)]
z∫qϕ(z∣x)logpθ(x∣z)dz=Eqϕ(z∣x)[logpθ(x∣z)]这一项是重构项,简而言之这一项是让模型输出与输入尽可能相似,对应的损失函数为
L
2
=
∥
x
−
x
^
∥
2
{L_2} = {\left\| {x - \hat x} \right\|^2}
L2=∥x−x^∥2
对于单变量高斯分布 p ( μ 1 , σ 1 ) p(\mu_1 ,\sigma_1 ) p(μ1,σ1)和 q ( μ 2 , σ 2 ) q(\mu_2 ,\sigma_2 ) q(μ2,σ2),它们的KL散度为
D K L ( p ∣ ∣ q ) = log σ 2 σ 1 + σ 1 2 + ( μ 1 − μ 2 ) 2 2 σ 2 2 − 1 2 {D_{KL}}(p||q) = \log {{{\sigma _2}} \over {{\sigma _1}}} + {{\sigma _1^2 + {{({\mu _1} - {\mu _2})}^2}} \over {2\sigma _2^2}} - {1 \over 2} DKL(p∣∣q)=logσ1σ2+2σ22σ12+(μ1−μ2)2−21
总结
以上是变分自编码器的原理和公式推导,本文详细描述了变分自编码器的结构与损失函数的证明过程,如有不足欢迎指正。
参考文献
原论文:Auto-Encoding Variational Bayes
https://blog.csdn.net/m0_56942491/article/details/136265500
https://www.cnblogs.com/wxkang/p/17128108.html
https://www.bilibili.com/video/BV1m3411p7wDp=44&vd_source=d171090cd498e0b803d66330aa04930e