论文阅读24 - VAE - Variational AutoEncoder (Auto-Encoding Variationl Bayes)

VAE

一篇讲的很好的博客

理论推导博客

论文原文

斯坦福课件

上面的博客已经很好很深入了,下面记录一下我个人的直观理解。具体理论移步上面的博客。这里只是作为日后使用时的快速查阅。不具有理论推导的严谨性。

1. 直观理解:

第一次接触VAE还是在World Model这篇论文。VAE主要由三部分组成:

  • Encoder 编码器,用来中间向量z分布,即 p ( z ∣ x ) p(z|x) p(zx)
  • z向量 Encoder的输出,Decoder的输入。可以当做降维之后的输入。 p ( z ) p(z) p(z)
  • Decoder 解码器,用来从z生成出原来的,即 p ( x ∣ z ) p(x|z) p(xz)

VAE训练好后,可以用中间变量z作为其他模型的输入World Model就是这么做的,这样Encoder就相当于一个降维的作用。也可以将Decoder作为生成器,生成和训练集类似的样例,这就和GAN的功能类似。

本质上,VAE就是我给一堆输入到编码器,解码器能输出同样分布的输出。

生成模型的难题就是判断生成分布真实分布相似度,因为我们只知道两者的采样结果,不知道它们的分布表达式。

KL散度的虽然能衡量两种分布的近似度,但是必须知道分布的表达式

我们的假设是 p ( z ∣ x ) p(z|x) p(zx)是高斯分布。这是VAE模型的重点,正因为这个假设,我们才设计成如下模型:

当然,如果 p ( z ∣ x ) p(z|x) p(zx)是高斯分布, p ( z ) p(z) p(z)也满足正态分布。推理如下(不区分积分与求和):

p ( z ) = ∑ x p ( z ∣ x ) p ( x ) = ∑ x N ( 0 , 1 ) p ( x ) = N ( 0 , 1 ) ∑ x p ( x ) = N ( 0 , 1 ) p(z) = \sum_x p(z|x)p(x) = \sum_xN(0,1)p(x) = N(0,1)\sum_xp(x) = N(0,1) p(z)=xp(zx)p(x)=xN(0,1)p(x)=N(0,1)xp(x)=N(0,1)

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-gIG4oJZ3-1605507771174)(24-VAE.assets/image-20201115160705282.png)]

图片来源

2. 结构

结构示意图如上图所示。

以图片为例:

对于均值方差计算模块:可能是多个卷积层池化层

生成器:可能是多个反卷积层

均值方差的计算则是全连接网络。

2.1 为什么要向标准正态分布看齐,以及如何实现的?

​ 我们的z是根据均值和方差采样而来,在这里方差相当于噪声,如果方差是0的话,则采样结果则一定是均值。我们通过最小化生成的 x ^ \hat x x^与输入的 x x x之间的距离,来进行训练。那么我们的方差网络会逐渐趋近于结果为0。这时就退化成了AutoEncoder。

​ VAE通过在损失函数中引入生成的高斯分布 N ( μ , σ 2 ) N(\mu,\sigma^2) N(μ,σ2)与标准的高斯分布 N ( 0 , 1 ) N(0,1) N(0,1)之间的KL散度,来让 p ( z ∣ x ) p(z|x) p(zx)的分布趋近于标准正态分布。

VAE相对于之前的AutoEncoder的一个显著提升就是它的生成能力。从正态分布中采样生成一个z,就可以生成一个比较合理的结果。而AutoEncoder不能保证中间的z向量是某一种分布,所以它对于没有见过的(训练过的)z生成能力比较差。

2.2 如何采样出z?

直接sample出z是不行的,采样的过程是不可导,没办法BP啊!!

解决办法就是:
z = μ + ϵ × σ z = \mu + \epsilon \times \sigma z=μ+ϵ×σ

ϵ \epsilon ϵ 是从N(0,1)中采样来的。

这种技巧叫做重参数。反向传播时候,需要让z能够分别对 μ \mu μ σ \sigma σ求偏导,而对于 ϵ \epsilon ϵ则不需要对他求导。故才采样出来也没关系。

损失函数:
L o s s ( θ ) = D ( x , x ^ ) + K L ( N ( μ , σ 2 ) ∣ ∣ N ( 0 , 1 ) ) Loss(\theta) = D(x,\hat x) + KL(N(\mu, \sigma^2) || N(0,1)) Loss(θ)=D(x,x^)+KL(N(μ,σ2)N(0,1))

D ( x , x ^ ) D(x,\hat x) D(x,x^)是输入样本与生成样本之间的距离,可以使均方误等。

KL部分的推导,对于一维情况:

KL的公式:
D K L ( p ∣ ∣ q ) = ∑ i = 1 N = p ( x i ) log ⁡ p ( x i ) q ( x i ) D_{KL}(p||q) = \sum_{i = 1}^N = p(x_i) \log \frac{p(x_i)}{q(x_i)} DKL(pq)=i=1N=p(xi)logq(xi)p(xi)

K L ( N ( μ , σ 2 ) ∣ ∣ N ( 0 , 1 ) ) = ∫ 1 2 π σ 2 exp ⁡ { − ( x − μ ) 2 2 σ 2 } × log ⁡ { 1 2 π σ 2 exp ⁡ ( − ( x − μ ) 2 / 2 σ ) 1 2 π exp ⁡ ( − x 2 / 2 ) } d x = 一 顿 猛 如 虎 的 化 简 = 1 2 ∫ 1 2 π σ 2 exp ⁡ { − ( x − μ ) 2 2 σ 2 } [ − log ⁡ σ 2 + x 2 − ( x − μ ) 2 / σ 2 ] d x KL(N(\mu, \sigma^2) || N(0,1)) = \int \frac{1}{\sqrt{2 \pi \sigma^2}} \exp\{\frac{-(x-\mu)^2}{2\sigma^2}\} \times \log \{\frac{\frac{1}{\sqrt{2 \pi \sigma^2}}\exp(-(x-\mu)^2/2\sigma^)} {\frac{1}{\sqrt{2 \pi}}\exp(-x^2/2)}\} dx\\ =一顿猛如虎的化简 \\ =\frac{1}{2} \int \frac{1}{\sqrt{2 \pi \sigma^2}} \exp\{\frac{-(x-\mu)^2}{2\sigma^2}\} [-\log \sigma^2 + x^2 - (x-\mu)^2/\sigma^2]dx KL(N(μ,σ2)N(0,1))=2πσ2 1exp{2σ2(xμ)2}×log{2π 1exp(x2/2)2πσ2 1exp((xμ)2/2σ)}dx==212πσ2 1exp{2σ2(xμ)2}[logσ2+x2(xμ)2/σ2]dx

积分结果计算:

可分成三个积分加和(就是分别乘以中括号里那三部分):

  • 第一个是 − log ⁡ σ 2 -\log \sigma^2 logσ2可作为常数提出来,剩下是标准正态分布的积分值为1, 故结果为 − log ⁡ σ 2 -\log \sigma^2 logσ2

  • 第二项是二阶矩,结果为 μ 2 + σ 2 \mu^2 + \sigma^2 μ2+σ2

  • 第三项是
    ∫ − ∞ + ∞ 1 2 π σ 2 exp ⁡ { − ( x − μ ) 2 2 σ 2 } ( − ( x − μ ) 2 / σ 2 ) d x = ∫ − ∞ + ∞ 1 2 π exp ⁡ { − ( x − μ ) 2 2 σ 2 } ( − ( x − μ ) 2 / σ 2 ) d ( x − μ ) σ = − 1 2 π ∫ − ∞ + ∞ e − 1 2 t 2 t 2 d t = − 1 2 π ∫ − ∞ + ∞ e − 1 2 t 2 t d t 2 2 = − 2 1 2 π ∫ 0 + ∞ e − m 2 m 1 / 2 d m = − 2 1 π ∫ 0 + ∞ e − m m 3 2 − 1 d m = − 2 1 π Γ ( 3 2 ) Γ ( 3 2 ) = Γ ( 1 2 + 1 ) = 1 2 Γ ( 1 / 2 ) = 1 2 π 所 以 上 述 积 分 = − 2 1 π × 1 2 π = − 1 \int_{-\infty}^{+\infty} \frac{1}{\sqrt{2 \pi \sigma^2}}\exp\{\frac{-(x-\mu)^2}{2\sigma^2}\} (-(x-\mu)^2/\sigma^2)dx \\ = \int_{-\infty}^{+\infty} \frac{1}{\sqrt{2 \pi }}\exp\{\frac{-(x-\mu)^2}{2\sigma^2}\} (-(x-\mu)^2/\sigma^2)d \frac{(x-\mu)}{\sigma} \\ = - \frac{1}{\sqrt{2\pi}} \int_{-\infty}^{+\infty} e^{-\frac{1}{2}t^2} t^2 dt \\ = -\frac{1}{\sqrt{2\pi}} \int_{-\infty}^{+\infty} e^{-\frac{1}{2}t^2} t d \frac{t^2}{2} \\ = -2\frac{1}{\sqrt{2\pi}}\int_{0}^{+\infty} e^{-m} \sqrt 2 m^{1/2} dm \\ = -2\frac{1}{\sqrt \pi} \int_{0}^{+\infty} e^{-m} m^{\frac{3}{2}-1} dm \\ = - 2\frac{1}{\sqrt \pi} \Gamma(\frac{3}{2}) \\ \Gamma(\frac{3}{2}) = \Gamma(\frac{1}{2}+1) = \frac{1}{2}\Gamma(1/2) = \frac{1}{2} \sqrt{\pi} \\ 所以上述积分=- 2\frac{1}{\sqrt \pi} \times\frac{1}{2} \sqrt{\pi} = -1 +2πσ2 1exp{2σ2(xμ)2}((xμ)2/σ2)dx=+2π 1exp{2σ2(xμ)2}((xμ)2/σ2)dσ(xμ)=2π 1+e21t2t2dt=2π 1+e21t2td2t2=22π 10+em2 m1/2dm=2π 10+emm231dm=2π 1Γ(23)Γ(23)=Γ(21+1)=21Γ(1/2)=21π =2π 1×21π =1

最终,KL散度为:
K L ( N ( μ , σ 2 ) ∣ ∣ N ( 0 , 1 ) ) = 1 2 ( − log ⁡ σ 2 + μ 2 + σ 2 − 1 ) KL(N(\mu, \sigma^2) || N(0,1)) = \frac{1}{2} (-\log \sigma^2 + \mu^2 + \sigma^2 -1) KL(N(μ,σ2)N(0,1))=21(logσ2+μ2+σ21)

上述只是针对一个维度。如果一共有j个维度,则需要把每个维度的KL散度都想相加。

  • 0
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值