下面的内容是从李宏毅2017年机器学习课程中关于VAE一节[1]中整理的。课程讲解的非常细致,再整理一遍方便理解查阅。
一、AE与VAE
AE(Auto-Encoder)是一个应用很广泛的机器学习方法。
主要内容即是:将输入(Input)经过编码器(encoder)压缩为一个编码(code),再通过解码器(decoder)将编码(code)解码为输出(Output)。
学习的目标即:要使得输出(Output)与输入(Input)越接近越好。
以输入为图像为例,结构图如下:
AE中间阶段生成的编码向量,并不是随机、没有意义的。编码中携带着与输入有关的信息,编码中的某些维度代表着输入数据的某些特征。例如生成人脸图像时,编码可以表示人脸表情、头发样子、是否有胡子等等。
VAE变分自动编码器作为AE的变体,它主要的变动是对编码(code)的生成上。编码(code)不再像AE中是唯一映射的,而是具有某种分布,使得编码(code)在某范围内波动时都可产生对应输出。借助下面这个例子进行理解:
如上图AE示意图,左侧是对满月图像编解码,右侧是对弦月图像编解码,而像中间的编码对解码器来说并不知道要生成何种图像。在VAE示意图中,左右两侧对图像编解码过程中,编码有不同程度的扰动(即图中noise),解码器利用扰动范围内的编码同样可以生成相应的图像,而对交界处的编码,编码器既想生成满月图像,又想生成弦月图像,为此做出折中,生成位于两者之间的图像。
这就是VAE一个较为直观的想法。
二、VAE原理
VAE是一个深度生成模型,其最终目的是生成出概率分布
P
(
x
)
P(x)
P(x),
x
x
x即输入数据。
在VAE中,我们通过高斯混合模型(Gaussian Mixture Model)来生成
P
(
x
)
P(x)
P(x),也就是说
P
(
x
)
P(x)
P(x)是由一系列高斯分布叠加而成的,每一个高斯分布都有它自己的参数
μ
\mu
μ和
σ
\sigma
σ。
那我们借助一个变量
z
∼
N
(
0
,
I
)
z\sim N(0,I)
z∼N(0,I)(注意
z
z
z是一个向量,生成自一个高斯分布),找一个映射关系,将向量
z
z
z映射成这一系列高斯分布的参数向量
μ
(
z
)
\mu (z)
μ(z)和
σ
(
z
)
\sigma (z)
σ(z)。有了这一系列高斯分布的参数我们就可以得到叠加后的
P
(
x
)
P(x)
P(x)的形式,即
x
∣
z
∼
N
(
μ
(
z
)
,
σ
(
z
)
)
x|z \sim N \big(\mu(z), \sigma(z)\big)
x∣z∼N(μ(z),σ(z))。(这里的“形式”仅是对某一个向量
z
z
z所得到的)。
那么要找的这个映射关系
P
(
x
∣
z
)
P(x|z)
P(x∣z)怎么获得呢?就拿神经网络来做呗,只要神经元足够想要啥样的函数得不到呢。如下图形式:
输入向量
z
z
z,得到参数向量
μ
(
z
)
\mu (z)
μ(z)和
σ
(
z
)
\sigma (z)
σ(z)。这个映射关系是要在训练过程中更新NN权重得到的。这部分作用相当于最终的解码器(decoder)。
对于某一个向量
z
z
z我们知道了如何找到
P
(
x
)
P(x)
P(x)。那么对连续变量
z
z
z依据全概率公式有:
P
(
x
)
=
∫
z
P
(
z
)
P
(
x
∣
z
)
d
z
P(x)=\int _{z} P(z)P(x|z)dz
P(x)=∫zP(z)P(x∣z)dz但是很难直接计算积分部分,因为我们很难穷举出所有的向量
z
z
z用于计算积分。又因为
P
(
x
)
P(x)
P(x)难以计算,那么真实的后验概率
P
(
z
∣
x
)
=
P
(
z
)
P
(
x
∣
z
)
/
P
(
x
)
P(z|x)=P(z)P(x|z)/P(x)
P(z∣x)=P(z)P(x∣z)/P(x)同样是不容易计算的,这也就是为什么下文要引入
q
(
z
∣
x
)
q(z|x)
q(z∣x)来近似真实后验概率
P
(
z
∣
x
)
P(z|x)
P(z∣x)。
因此我们用极大似然估计来估计
P
(
x
)
P(x)
P(x),有似然函数
L
L
L:
L
=
∑
x
log
P
(
x
)
L=\sum_{x}\log P(x)
L=x∑logP(x)这里我们额外引入一个分布
q
(
z
∣
x
)
q(z|x)
q(z∣x),
z
∣
x
∼
N
(
μ
′
(
x
)
,
σ
′
(
x
)
)
z|x \sim N\big(\mu^\prime(x), \sigma^\prime(x)\big)
z∣x∼N(μ′(x),σ′(x))。这个分布表示形式如下:
这个分布同样是用一个神经网络来完成,向量
z
z
z根据NN输出的参数向量
μ
′
(
x
)
\mu '(x)
μ′(x)和
σ
′
(
x
)
\sigma '(x)
σ′(x)运算得到,注意这三个向量具有相同的维度。这部分作用相当于最终的编码器(encoder)。
之后就开始推导了。
log
P
(
x
)
=
∫
z
q
(
z
∣
x
)
log
P
(
x
)
d
z
∵
∫
z
q
(
z
∣
x
)
d
z
=
1
=
∫
z
q
(
z
∣
x
)
log
P
(
z
,
x
)
P
(
z
∣
x
)
d
z
=
∫
z
q
(
z
∣
x
)
log
(
P
(
z
,
x
)
q
(
z
∣
x
)
⋅
q
(
z
∣
x
)
P
(
z
∣
x
)
)
d
z
=
∫
z
q
(
z
∣
x
)
log
q
(
z
∣
x
)
P
(
z
∣
x
)
d
z
+
∫
z
q
(
z
∣
x
)
log
P
(
z
,
x
)
q
(
z
∣
x
)
d
z
=
D
K
L
(
q
(
z
∣
x
)
∣
∣
P
(
z
∣
x
)
)
+
∫
z
q
(
z
∣
x
)
log
P
(
z
,
x
)
q
(
z
∣
x
)
d
z
⪖
∫
z
q
(
z
∣
x
)
log
P
(
z
,
x
)
q
(
z
∣
x
)
d
z
∵
D
K
L
(
q
∣
∣
P
)
⪖
0
\begin{aligned} \log P(x)&=\int_{z}q(z|x)\log P(x)dz \qquad \because \int_{z}q(z|x)dz=1 \\ &=\int_{z} q(z|x)\log \frac{P(z, x)}{P(z|x)}dz \\ &=\int_z q(z|x)\log \big(\frac{P(z,x)}{q(z|x)} \cdot \frac{q(z|x)}{P(z|x)}\big)dz \\ &=\int_z q(z|x)\log \frac{q(z|x)}{P(z|x)}dz + \int_z q(z|x)\log \frac{P(z,x)}{q(z|x)}dz \\ &=D_{KL}\big(q(z|x)||P(z|x)\big) + \int_z q(z|x)\log \frac{P(z,x)}{q(z|x)}dz \\ &\eqslantgtr \int_z q(z|x)\log \frac{P(z,x)}{q(z|x)}dz \qquad \because D_{KL}(q||P) \eqslantgtr 0 \end{aligned}
logP(x)=∫zq(z∣x)logP(x)dz∵∫zq(z∣x)dz=1=∫zq(z∣x)logP(z∣x)P(z,x)dz=∫zq(z∣x)log(q(z∣x)P(z,x)⋅P(z∣x)q(z∣x))dz=∫zq(z∣x)logP(z∣x)q(z∣x)dz+∫zq(z∣x)logq(z∣x)P(z,x)dz=DKL(q(z∣x)∣∣P(z∣x))+∫zq(z∣x)logq(z∣x)P(z,x)dz⪖∫zq(z∣x)logq(z∣x)P(z,x)dz∵DKL(q∣∣P)⪖0
我们将
∫
z
q
(
z
∣
x
)
log
P
(
z
,
x
)
q
(
z
∣
x
)
d
z
\int_z q(z|x)\log \frac{P(z,x)}{q(z|x)}dz
∫zq(z∣x)logq(z∣x)P(z,x)dz称为
log
P
(
x
)
\log P(x)
logP(x)的
(variational)
lower
bound
\textit{\textbf{(variational) lower bound}}
(variational) lower bound (变分下界),简称为
L
b
L_b
Lb。最大化
L
b
L_b
Lb就等价于最大化似然函数
L
L
L。那么接下来具体看
L
b
L_b
Lb,
L
b
=
∫
z
q
(
z
∣
x
)
log
P
(
z
,
x
)
q
(
z
∣
x
)
d
z
=
∫
z
q
(
z
∣
x
)
log
(
P
(
z
)
q
(
z
∣
x
)
⋅
P
(
x
∣
z
)
)
d
z
=
∫
z
q
(
z
∣
x
)
log
P
(
z
)
q
(
z
∣
x
)
d
z
+
∫
z
q
(
z
∣
x
)
log
P
(
x
∣
z
)
d
z
=
−
D
K
L
(
q
(
z
∣
x
)
∣
∣
P
(
z
)
)
+
∫
z
q
(
z
∣
x
)
log
P
(
x
∣
z
)
d
z
=
−
D
K
L
(
q
(
z
∣
x
)
∣
∣
P
(
z
)
)
+
E
q
(
z
∣
x
)
[
log
P
(
x
∣
z
)
]
\begin{aligned} L_b&=\int_z q(z|x)\log \frac{P(z,x)}{q(z|x)}dz \\ &=\int_z q(z|x)\log \big(\frac{P(z)}{q(z|x)} \cdot P(x|z) \big)dz \\ &=\int_z q(z|x)\log \frac{P(z)}{q(z|x)}dz + \int_z q(z|x)\log P(x|z)dz \\ &=-D_{KL}\big( q(z|x)||P(z)\big) + \int_z q(z|x)\log P(x|z)dz \\ &=-D_{KL}\big( q(z|x)||P(z)\big) + E_{q(z|x)}[\log P(x|z)] \end{aligned}
Lb=∫zq(z∣x)logq(z∣x)P(z,x)dz=∫zq(z∣x)log(q(z∣x)P(z)⋅P(x∣z))dz=∫zq(z∣x)logq(z∣x)P(z)dz+∫zq(z∣x)logP(x∣z)dz=−DKL(q(z∣x)∣∣P(z))+∫zq(z∣x)logP(x∣z)dz=−DKL(q(z∣x)∣∣P(z))+Eq(z∣x)[logP(x∣z)]
最大化
L
b
L_b
Lb包括下面两部分:
-
minimizing
\textit{minimizing}
minimizing
D
K
L
(
q
(
z
∣
x
)
∣
∣
P
(
z
)
)
D_{KL}\big( q(z|x)||P(z)\big)
DKL(q(z∣x)∣∣P(z)),使后验分布近似值
q
(
z
∣
x
)
q(z|x)
q(z∣x)接近先验分布
P
(
z
)
P(z)
P(z)。也就是说通过
q
(
z
∣
x
)
q(z|x)
q(z∣x)生成的编码
z
z
z不能太离谱,要与某个分布相当才行,这里是对中间编码生成起了限制作用。
当 q ( z ∣ x ) q(z|x) q(z∣x)和 P ( z ) P(z) P(z)都是高斯分布时,推导式有([2]中Appendix B): D K L ( q ( z ∣ x ) ∣ ∣ P ( z ) ) = − 1 2 ∑ j J ( 1 + log ( σ j ) 2 − ( μ j ) 2 − ( σ j ) 2 ) D_{KL}\big( q(z|x)||P(z)\big)=-\frac{1}{2}\sum_{j}^{J}\big( 1+\log (\sigma_{j})^2 - (\mu_j)^2-(\sigma_j)^2\big) DKL(q(z∣x)∣∣P(z))=−21j∑J(1+log(σj)2−(μj)2−(σj)2)其中 J J J表示向量 z z z的总维度数, σ j \sigma_j σj和 μ j \mu_j μj表示 q ( z ∣ x ) q(z|x) q(z∣x)输出的参数向量 σ \sigma σ和 μ \mu μ的第 j j j个元素。(这里的 σ \sigma σ和 μ \mu μ等于前文中 μ ′ ( x ) \mu '(x) μ′(x)和 σ ′ ( x ) \sigma '(x) σ′(x)) - maximizing \textit{maximizing} maximizing E q ( z ∣ x ) [ log P ( x ∣ z ) ] E_{q(z|x)}[\log P(x|z)] Eq(z∣x)[logP(x∣z)],即在给定编码器输出 q ( z ∣ x ) q(z|x) q(z∣x)下解码器输出 P ( x ∣ z ) P(x|z) P(x∣z)越大越好。这部分也就相当于最小化Reconstruction Error(重建损失)。
补充点:重建损失函数选择交叉熵损失还是平方差损失,是跟 P ( x ∣ z ) P(x|z) P(x∣z)形式有关的,再取对数似然。知乎回答[6]和专栏[7]中有进行讲解说明。引用[6]中用户Taffy lll的回答:
重建损失的数学形式是对数似然 log p ( x ∣ z ) \log p(x|z) logp(x∣z),它的具体表达式和 p ( x ∣ z ) p(x|z) p(x∣z)相关。一般来说, p ( x ∣ z ) p(x|z) p(x∣z)的选取和 x x x的取值空间是密切相关的: 如果x是二值图像,这个概率一般用伯努利分布,而伯努利分布的对数似然就是binary cross entropy,可以调各大DL库里的BCE函数;如果x是彩色/灰度图像,这个概率取高斯分布,那么高斯分布的对数似然就是平方差。
由此我们可以得出VAE的原理图:
通常忽略掉decoder输出的
σ
(
x
)
\sigma(x)
σ(x)一项,仅要求
μ
(
x
)
\mu(x)
μ(x)与
x
x
x越接近越好。
对某一输入数据
x
x
x来说,VAE的损失函数即:
min
L
o
s
s
V
A
E
=
D
K
L
(
q
(
z
∣
x
)
∣
∣
P
(
z
)
)
−
E
q
(
z
∣
x
)
[
log
P
(
x
∣
z
)
]
\min Loss_{VAE} = D_{KL}\big( q(z|x)||P(z)\big) -E_{q(z|x)}[\log P(x|z)]
minLossVAE=DKL(q(z∣x)∣∣P(z))−Eq(z∣x)[logP(x∣z)]
附:
极大似然估计
P
(
x
)
P(x)
P(x)的时候还有一种写法,即通过
P
(
x
)
=
∫
z
P
(
x
,
z
)
d
z
P(x)=\int_z P(x,z)dz
P(x)=∫zP(x,z)dz来推导。如图[3]:
里边有提到术语ELBO,Evidence Lower BOund(证据下界),有兴趣的可以自行查阅了解(也就是上文提到的变分下界,不过ELBO叫法更普遍)。
三、reparameterization trick
由上文中VAE原理图可以看出,
z
∼
q
(
z
∣
x
)
z \sim q(z|x)
z∼q(z∣x),即编码
z
z
z是由分布
q
(
z
∣
x
)
q(z|x)
q(z∣x)采样产生,而采样操作是不可微分的,因此反向传播做不了。[2]中提到了reparameterization trick来解决,借助[4]中的示意图理解下:
将上图左图原来的采样操作通过reparameterization trick变换为右图的形式。
我们引入一个外部向量 ϵ ∼ N ( 0 , I ) \bm\epsilon \sim N(\textbf{0}, \textbf{I}) ϵ∼N(0,I),通过 z = μ + σ ⊙ ϵ \textbf{z}=\bm\mu + \bm\sigma \odot \bm\epsilon z=μ+σ⊙ϵ计算编码 z \textbf{z} z( ⊙ \odot ⊙表示element-wise乘法, ϵ \bm\epsilon ϵ的每一维都服从标准高斯分布即 ϵ i ∼ N ( 0 , 1 ) \epsilon_i \sim N(0,1) ϵi∼N(0,1)),由此loss的梯度可以通过 μ \bm\mu μ和 σ \bm\sigma σ分支传递到encoder model处( ϵ \bm\epsilon ϵ并不需要梯度信息来更新)。
这里利用了这样一个事实[5]:
考虑单变量高斯分布,假设 z ∼ p ( z ∣ x ) = N ( μ , σ 2 ) z \sim p(z|x)=N(\mu, \sigma^2) z∼p(z∣x)=N(μ,σ2),从中采样一个 z z z,就相当于先从 N ( 0 , 1 ) N(0, 1) N(0,1)中采样一个 ϵ \epsilon ϵ,再令 z = μ + σ ⊙ ϵ z=\mu + \sigma \odot \epsilon z=μ+σ⊙ϵ。
最终的VAE实际形式如下图所示:
四、不足
VAE在产生新数据的时候是基于已有数据来做的,或者说是对已有数据进行某种组合而得到新数据的,它并不能生成或创造出新数据。另一方面是VAE产生的图像比较模糊。
而大名鼎鼎的GAN利用对抗学习的方式,既能生成新数据,也能产生较清晰的图像。后续的更是出现了很多种变形。
五、参考文献
[1] Unsupervised Learning: Deep Generative Model (2017/04/27)
[2] VAE原著Auto-encoding variational bayes
[3] VAE的三种不同推导方法
[4] https://www.jeremyjordan.me/variational-autoencoders/
[5] 变分自编码器VAE:原来是这么一回事
[6] 变分自编码器的重建损失为什么有人用交叉熵损失?有人用平方差?
[7] 再谈变分自编码器VAE:从贝叶斯观点出发
[8] posterior collapse 后验消失问题是什么