【深度学习】VAE(Variational Auto-Encoder)原理

下面的内容是从李宏毅2017年机器学习课程中关于VAE一节[1]中整理的。课程讲解的非常细致,再整理一遍方便理解查阅。


一、AE与VAE

AE(Auto-Encoder)是一个应用很广泛的机器学习方法。
主要内容即是:将输入(Input)经过编码器(encoder)压缩为一个编码(code),再通过解码器(decoder)将编码(code)解码为输出(Output)。
学习的目标即:要使得输出(Output)与输入(Input)越接近越好。
以输入为图像为例,结构图如下:
在这里插入图片描述
AE中间阶段生成的编码向量,并不是随机、没有意义的。编码中携带着与输入有关的信息,编码中的某些维度代表着输入数据的某些特征。例如生成人脸图像时,编码可以表示人脸表情、头发样子、是否有胡子等等。
VAE变分自动编码器作为AE的变体,它主要的变动是对编码(code)的生成上。编码(code)不再像AE中是唯一映射的,而是具有某种分布,使得编码(code)在某范围内波动时都可产生对应输出。借助下面这个例子进行理解:
在这里插入图片描述
如上图AE示意图,左侧是对满月图像编解码,右侧是对弦月图像编解码,而像中间的编码对解码器来说并不知道要生成何种图像。在VAE示意图中,左右两侧对图像编解码过程中,编码有不同程度的扰动(即图中noise),解码器利用扰动范围内的编码同样可以生成相应的图像,而对交界处的编码,编码器既想生成满月图像,又想生成弦月图像,为此做出折中,生成位于两者之间的图像。
这就是VAE一个较为直观的想法。


二、VAE原理

VAE是一个深度生成模型,其最终目的是生成出概率分布 P ( x ) P(x) P(x) x x x即输入数据。
在VAE中,我们通过高斯混合模型(Gaussian Mixture Model)来生成 P ( x ) P(x) P(x),也就是说 P ( x ) P(x) P(x)是由一系列高斯分布叠加而成的,每一个高斯分布都有它自己的参数 μ \mu μ σ \sigma σ
在这里插入图片描述
那我们借助一个变量 z ∼ N ( 0 , I ) z\sim N(0,I) zN(0,I)(注意 z z z是一个向量,生成自一个高斯分布),找一个映射关系,将向量 z z z映射成这一系列高斯分布的参数向量 μ ( z ) \mu (z) μ(z) σ ( z ) \sigma (z) σ(z)。有了这一系列高斯分布的参数我们就可以得到叠加后的 P ( x ) P(x) P(x)的形式,即 x ∣ z ∼ N ( μ ( z ) , σ ( z ) ) x|z \sim N \big(\mu(z), \sigma(z)\big) xzN(μ(z),σ(z))。(这里的“形式”仅是对某一个向量 z z z所得到的)。
那么要找的这个映射关系 P ( x ∣ z ) P(x|z) P(xz)怎么获得呢?就拿神经网络来做呗,只要神经元足够想要啥样的函数得不到呢。如下图形式:
在这里插入图片描述
输入向量 z z z,得到参数向量 μ ( z ) \mu (z) μ(z) σ ( z ) \sigma (z) σ(z)。这个映射关系是要在训练过程中更新NN权重得到的。这部分作用相当于最终的解码器(decoder)

对于某一个向量 z z z我们知道了如何找到 P ( x ) P(x) P(x)。那么对连续变量 z z z依据全概率公式有:
P ( x ) = ∫ z P ( z ) P ( x ∣ z ) d z P(x)=\int _{z} P(z)P(x|z)dz P(x)=zP(z)P(xz)dz但是很难直接计算积分部分,因为我们很难穷举出所有的向量 z z z用于计算积分。又因为 P ( x ) P(x) P(x)难以计算,那么真实的后验概率 P ( z ∣ x ) = P ( z ) P ( x ∣ z ) / P ( x ) P(z|x)=P(z)P(x|z)/P(x) P(zx)=P(z)P(xz)/P(x)同样是不容易计算的,这也就是为什么下文要引入 q ( z ∣ x ) q(z|x) q(zx)来近似真实后验概率 P ( z ∣ x ) P(z|x) P(zx)
因此我们用极大似然估计来估计 P ( x ) P(x) P(x),有似然函数 L L L L = ∑ x log ⁡ P ( x ) L=\sum_{x}\log P(x) L=xlogP(x)这里我们额外引入一个分布 q ( z ∣ x ) q(z|x) q(zx) z ∣ x ∼ N ( μ ′ ( x ) , σ ′ ( x ) ) z|x \sim N\big(\mu^\prime(x), \sigma^\prime(x)\big) zxN(μ(x),σ(x))。这个分布表示形式如下:
在这里插入图片描述这个分布同样是用一个神经网络来完成,向量 z z z根据NN输出的参数向量 μ ′ ( x ) \mu '(x) μ(x) σ ′ ( x ) \sigma '(x) σ(x)运算得到,注意这三个向量具有相同的维度。这部分作用相当于最终的编码器(encoder)
之后就开始推导了。
log ⁡ P ( x ) = ∫ z q ( z ∣ x ) log ⁡ P ( x ) d z ∵ ∫ z q ( z ∣ x ) d z = 1 = ∫ z q ( z ∣ x ) log ⁡ P ( z , x ) P ( z ∣ x ) d z = ∫ z q ( z ∣ x ) log ⁡ ( P ( z , x ) q ( z ∣ x ) ⋅ q ( z ∣ x ) P ( z ∣ x ) ) d z = ∫ z q ( z ∣ x ) log ⁡ q ( z ∣ x ) P ( z ∣ x ) d z + ∫ z q ( z ∣ x ) log ⁡ P ( z , x ) q ( z ∣ x ) d z = D K L ( q ( z ∣ x ) ∣ ∣ P ( z ∣ x ) ) + ∫ z q ( z ∣ x ) log ⁡ P ( z , x ) q ( z ∣ x ) d z ⪖ ∫ z q ( z ∣ x ) log ⁡ P ( z , x ) q ( z ∣ x ) d z ∵ D K L ( q ∣ ∣ P ) ⪖ 0 \begin{aligned} \log P(x)&=\int_{z}q(z|x)\log P(x)dz \qquad \because \int_{z}q(z|x)dz=1 \\ &=\int_{z} q(z|x)\log \frac{P(z, x)}{P(z|x)}dz \\ &=\int_z q(z|x)\log \big(\frac{P(z,x)}{q(z|x)} \cdot \frac{q(z|x)}{P(z|x)}\big)dz \\ &=\int_z q(z|x)\log \frac{q(z|x)}{P(z|x)}dz + \int_z q(z|x)\log \frac{P(z,x)}{q(z|x)}dz \\ &=D_{KL}\big(q(z|x)||P(z|x)\big) + \int_z q(z|x)\log \frac{P(z,x)}{q(z|x)}dz \\ &\eqslantgtr \int_z q(z|x)\log \frac{P(z,x)}{q(z|x)}dz \qquad \because D_{KL}(q||P) \eqslantgtr 0 \end{aligned} logP(x)=zq(zx)logP(x)dzzq(zx)dz=1=zq(zx)logP(zx)P(z,x)dz=zq(zx)log(q(zx)P(z,x)P(zx)q(zx))dz=zq(zx)logP(zx)q(zx)dz+zq(zx)logq(zx)P(z,x)dz=DKL(q(zx)P(zx))+zq(zx)logq(zx)P(z,x)dzzq(zx)logq(zx)P(z,x)dzDKL(qP)0
我们将 ∫ z q ( z ∣ x ) log ⁡ P ( z , x ) q ( z ∣ x ) d z \int_z q(z|x)\log \frac{P(z,x)}{q(z|x)}dz zq(zx)logq(zx)P(z,x)dz称为 log ⁡ P ( x ) \log P(x) logP(x) (variational)   lower   bound \textit{\textbf{(variational) lower bound}} (variational) lower bound (变分下界),简称为 L b L_b Lb最大化 L b L_b Lb就等价于最大化似然函数 L L L。那么接下来具体看 L b L_b Lb
L b = ∫ z q ( z ∣ x ) log ⁡ P ( z , x ) q ( z ∣ x ) d z = ∫ z q ( z ∣ x ) log ⁡ ( P ( z ) q ( z ∣ x ) ⋅ P ( x ∣ z ) ) d z = ∫ z q ( z ∣ x ) log ⁡ P ( z ) q ( z ∣ x ) d z + ∫ z q ( z ∣ x ) log ⁡ P ( x ∣ z ) d z = − D K L ( q ( z ∣ x ) ∣ ∣ P ( z ) ) + ∫ z q ( z ∣ x ) log ⁡ P ( x ∣ z ) d z = − D K L ( q ( z ∣ x ) ∣ ∣ P ( z ) ) + E q ( z ∣ x ) [ log ⁡ P ( x ∣ z ) ] \begin{aligned} L_b&=\int_z q(z|x)\log \frac{P(z,x)}{q(z|x)}dz \\ &=\int_z q(z|x)\log \big(\frac{P(z)}{q(z|x)} \cdot P(x|z) \big)dz \\ &=\int_z q(z|x)\log \frac{P(z)}{q(z|x)}dz + \int_z q(z|x)\log P(x|z)dz \\ &=-D_{KL}\big( q(z|x)||P(z)\big) + \int_z q(z|x)\log P(x|z)dz \\ &=-D_{KL}\big( q(z|x)||P(z)\big) + E_{q(z|x)}[\log P(x|z)] \end{aligned} Lb=zq(zx)logq(zx)P(z,x)dz=zq(zx)log(q(zx)P(z)P(xz))dz=zq(zx)logq(zx)P(z)dz+zq(zx)logP(xz)dz=DKL(q(zx)P(z))+zq(zx)logP(xz)dz=DKL(q(zx)P(z))+Eq(zx)[logP(xz)]
最大化 L b L_b Lb包括下面两部分:

  • minimizing \textit{minimizing} minimizing D K L ( q ( z ∣ x ) ∣ ∣ P ( z ) ) D_{KL}\big( q(z|x)||P(z)\big) DKL(q(zx)P(z))使后验分布近似值 q ( z ∣ x ) q(z|x) q(zx)接近先验分布 P ( z ) P(z) P(z)。也就是说通过 q ( z ∣ x ) q(z|x) q(zx)生成的编码 z z z不能太离谱,要与某个分布相当才行,这里是对中间编码生成起了限制作用。
    q ( z ∣ x ) q(z|x) q(zx) P ( z ) P(z) P(z)都是高斯分布时,推导式有([2]中Appendix B): D K L ( q ( z ∣ x ) ∣ ∣ P ( z ) ) = − 1 2 ∑ j J ( 1 + log ⁡ ( σ j ) 2 − ( μ j ) 2 − ( σ j ) 2 ) D_{KL}\big( q(z|x)||P(z)\big)=-\frac{1}{2}\sum_{j}^{J}\big( 1+\log (\sigma_{j})^2 - (\mu_j)^2-(\sigma_j)^2\big) DKL(q(zx)P(z))=21jJ(1+log(σj)2(μj)2(σj)2)其中 J J J表示向量 z z z的总维度数, σ j \sigma_j σj μ j \mu_j μj表示 q ( z ∣ x ) q(z|x) q(zx)输出的参数向量 σ \sigma σ μ \mu μ的第 j j j个元素。(这里的 σ \sigma σ μ \mu μ等于前文中 μ ′ ( x ) \mu '(x) μ(x) σ ′ ( x ) \sigma '(x) σ(x))
  • maximizing \textit{maximizing} maximizing E q ( z ∣ x ) [ log ⁡ P ( x ∣ z ) ] E_{q(z|x)}[\log P(x|z)] Eq(zx)[logP(xz)],即在给定编码器输出 q ( z ∣ x ) q(z|x) q(zx)下解码器输出 P ( x ∣ z ) P(x|z) P(xz)越大越好。这部分也就相当于最小化Reconstruction Error(重建损失)。

补充点:重建损失函数选择交叉熵损失还是平方差损失,是跟 P ( x ∣ z ) P(x|z) P(xz)形式有关的,再取对数似然。知乎回答[6]和专栏[7]中有进行讲解说明。引用[6]中用户Taffy lll的回答:

重建损失的数学形式是对数似然 log ⁡ p ( x ∣ z ) \log p(x|z) logp(xz),它的具体表达式和 p ( x ∣ z ) p(x|z) p(xz)相关。一般来说, p ( x ∣ z ) p(x|z) p(xz)的选取和 x x x的取值空间是密切相关的: 如果x是二值图像,这个概率一般用伯努利分布,而伯努利分布的对数似然就是binary cross entropy,可以调各大DL库里的BCE函数;如果x是彩色/灰度图像,这个概率取高斯分布,那么高斯分布的对数似然就是平方差。

由此我们可以得出VAE的原理图
在这里插入图片描述
通常忽略掉decoder输出的 σ ( x ) \sigma(x) σ(x)一项,仅要求 μ ( x ) \mu(x) μ(x) x x x越接近越好。
对某一输入数据 x x x来说,VAE的损失函数即:
min ⁡ L o s s V A E = D K L ( q ( z ∣ x ) ∣ ∣ P ( z ) ) − E q ( z ∣ x ) [ log ⁡ P ( x ∣ z ) ] \min Loss_{VAE} = D_{KL}\big( q(z|x)||P(z)\big) -E_{q(z|x)}[\log P(x|z)] minLossVAE=DKL(q(zx)P(z))Eq(zx)[logP(xz)]
附:
极大似然估计 P ( x ) P(x) P(x)的时候还有一种写法,即通过 P ( x ) = ∫ z P ( x , z ) d z P(x)=\int_z P(x,z)dz P(x)=zP(x,z)dz来推导。如图[3]:
在这里插入图片描述
里边有提到术语ELBOEvidence Lower BOund(证据下界),有兴趣的可以自行查阅了解(也就是上文提到的变分下界,不过ELBO叫法更普遍)。


三、reparameterization trick

由上文中VAE原理图可以看出, z ∼ q ( z ∣ x ) z \sim q(z|x) zq(zx),即编码 z z z是由分布 q ( z ∣ x ) q(z|x) q(zx)采样产生,而采样操作是不可微分的,因此反向传播做不了。[2]中提到了reparameterization trick来解决,借助[4]中的示意图理解下:
在这里插入图片描述
将上图左图原来的采样操作通过reparameterization trick变换为右图的形式。

我们引入一个外部向量 ϵ ∼ N ( 0 , I ) \bm\epsilon \sim N(\textbf{0}, \textbf{I}) ϵN(0,I),通过 z = μ + σ ⊙ ϵ \textbf{z}=\bm\mu + \bm\sigma \odot \bm\epsilon z=μ+σϵ计算编码 z \textbf{z} z ⊙ \odot 表示element-wise乘法, ϵ \bm\epsilon ϵ的每一维都服从标准高斯分布即 ϵ i ∼ N ( 0 , 1 ) \epsilon_i \sim N(0,1) ϵiN(0,1)),由此loss的梯度可以通过 μ \bm\mu μ σ \bm\sigma σ分支传递到encoder model处( ϵ \bm\epsilon ϵ并不需要梯度信息来更新)。

这里利用了这样一个事实[5]:

考虑单变量高斯分布,假设 z ∼ p ( z ∣ x ) = N ( μ , σ 2 ) z \sim p(z|x)=N(\mu, \sigma^2) zp(zx)=N(μ,σ2),从中采样一个 z z z,就相当于先从 N ( 0 , 1 ) N(0, 1) N(0,1)中采样一个 ϵ \epsilon ϵ,再令 z = μ + σ ⊙ ϵ z=\mu + \sigma \odot \epsilon z=μ+σϵ

最终的VAE实际形式如下图所示:

在这里插入图片描述


四、不足

VAE在产生新数据的时候是基于已有数据来做的,或者说是对已有数据进行某种组合而得到新数据的,它并不能生成或创造出新数据。另一方面是VAE产生的图像比较模糊。
而大名鼎鼎的GAN利用对抗学习的方式,既能生成新数据,也能产生较清晰的图像。后续的更是出现了很多种变形。


五、参考文献

[1] Unsupervised Learning: Deep Generative Model (2017/04/27)
[2] VAE原著Auto-encoding variational bayes
[3] VAE的三种不同推导方法
[4] https://www.jeremyjordan.me/variational-autoencoders/
[5] 变分自编码器VAE:原来是这么一回事
[6] 变分自编码器的重建损失为什么有人用交叉熵损失?有人用平方差?
[7] 再谈变分自编码器VAE:从贝叶斯观点出发
[8] posterior collapse 后验消失问题是什么

  • 19
    点赞
  • 75
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值