VAE理论详解,从概率的角度出发,还原论文推导

需要了解一下AE的知识。并且需要熟悉掌握GAN的理论知识。网上太多VAE讲解的论文,很多很难讲明白,包括VAE编解码输入输出loss在代码里面是什么都没讲明白。我相信看完这个,自己都能用代码复现,欢迎讨论。

1 简介

参考文章:
https://zhuanlan.zhihu.com/p/682709613
https://arxiv.org/pdf/1312.6114
https://arxiv.org/pdf/1606.05908v2
如下图所示,生成模型的目标是在一个已只分布p(z)中随机采样z,经过网络G,生成的结果 x ~ = G ( z ) \widetilde{x}=G(z) x =G(z)是满足训练数据p(x)的分布,我们假定生成的分布为 p g p_g pg,训练样本的分布为 p d a t a p_{data} pdata,一个非常麻烦的事是我们不知道,也无法去知道 p d a t a p_{data} pdata的分布,无法去做损失函数求解G。你可以这样理解,我们生成网络的目的是要得到 p g = p d a t a p_g=p_{data} pg=pdata,如果我都知道 p d a t a p_{data} pdata,我直接在 p d a t a p_{data} pdata采样不就完事了,还需要生成网络干嘛?让我们来回顾一下GAN网络,GAN怎么做的呢,GAN网络结构引入了D判别器,可以去翻一下前面的GAN,你会发现所有的损失函数是在D网络出来结果的损失,进而去约束G网络,实际上根本没有去求解 p d a t a p_{data} pdata的分布,通过对D做损失优化,最终G网络生成的 p g p_g pg是等于 p d a t a p_{data} pdata的,不得不佩服D网络引入的巧妙性。
在这里插入图片描述

那么VAE是怎么做的,通过我们前面那么多介绍,想必应该很清楚了,单独只有一个G网络,根本是无法实现生成任务的。GAN是在后面加的判别器能更好的求解损失。那么能否在前面加一个什么网络,使我们的损失函数好做一些,能够求解呢,当然VAE便是如此,在前面加上一个解码网络,接下来我们看看VAE这个过程。
模型结构:先简单解释一下流程,Q是一个编码器,输出的结果是均值和方差,在这个均值方差的正太分布上采样一个z,输入解码网络P得到生成的结果。接下来将围绕这个结构来详细说一下VAE。
在这里插入图片描述

AE自编码器(Autoencoder),是把输入X编码到一个laten space中,通过一个低维向量来表示X。VAE变分自编码器(Variational AutoEncoders),laten space是满足正太分布。由于AE的laten space不是一个分布,无法从laten space采样。而如果想要有生成能力,VAE巧妙的使laten space满足正太分布,这样在正太分布上采样即完成了一个生成模型。
从模型结构我们很好理解VAE这样做是合理的,只要把X编码到一个特定分布的laten space中,从而在这个特定分布采样到解码网络即完成生成,重建损失就是AE的重建损失,为了使laten space满足特定的分布,在加上一个KL散度来约束编码器,而事实上VAE中的V(变分)就是因为VAE的推导就是因为用到了KL散度(进而也包含了变分法)。
这样整个训练和loss也就出来,因为laten space 的分布是我们假定的特定分布(如标准正太分布),因此KL散度是可求的,重建损失也就跟AE是一样的,只需要求输入和最终输出结果的距离即可,简直很完美。然而这些都真是我们想的是这样,背后的理论依据又是怎么样的。

2 理论推导

首先要解释的一点是,样本 X = { x ( i ) } 1 N X=\{x^{(i)}\}^N_1 X={x(i)}1N是独立通分布。首先作者定义了一种分布 p θ p_{\theta} pθ,参数为 θ {\theta} θ,输入x与隐变量z的关系可以表示为:
p θ ( z ) p_{\theta}(z) pθ(z)表示隐变量z的先验分布;
p θ ( x ∣ z ) p_{\theta}(x|z) pθ(xz)为释然(由果到因);
p θ ( z ∣ x ) p_{\theta}(z|x) pθ(zx)为后验(由因到果)。
接下来作者做了个假设,假设我已经知道了真实参数 θ ∗ {\theta}^* θ,分为两步:第一步从先验 p θ ∗ ( z ) p_{{\theta}^{*}}(z) pθ(z)中采样 z ( i ) z^{(i)} z(i);第二步,由 p θ ( x ∣ z = z i ) p_{\theta}(x|z=z^{i}) pθ(xz=zi)可得到 x ( i ) x^{(i)} x(i)。那么现在问题变成求解参数 θ {\theta} θ,我们就想到释然函数,最优的 θ {\theta} θ必然是最后生成 x ( i ) x^{(i)} x(i)的概率乘积最大,也就是最大释然函数,我们先把释然函数写出来(用log表示,按照作者原论文形式写出):
l o g ( p θ ( x ( i ) , . . . , x ( N ) ) ) = ∑ i = 1 N l o g p θ ( x ( i ) ) log(p_{\theta}(x^{(i)},...,x^{(N)})) = \sum_{i=1}^{N}logp_{\theta}(x^{(i)}) log(pθ(x(i),...,x(N)))=i=1Nlogpθ(x(i))
我们知道 p θ ( x ) = ∫ p θ ( z ) p θ ( x ∣ z ) d z p_{\theta}(x)=\int{p_{\theta}(z)}p_{\theta}(x|z)dz pθ(x)=pθ(z)pθ(xz)dz很难求解,上面的释然函数自然没法求解。并且后验 p θ ( z ∣ x ) = p θ ( x ∣ z ) p θ ( z ) / p θ ( x ) p_{\theta}(z|x)=p_{\theta}(x|z)p_{\theta}(z)/p_{\theta}(x) pθ(zx)=pθ(xz)pθ(z)/pθ(x)也难求解,怎么办呢。既然这么难解怎么办呢,那就不解了,交给编码网络吧。
引入识别模型,也就是编码网络得到 q ϕ ( z ∣ x ) q_{\phi}(z|x) qϕ(zx)近似真实分布 p θ ( z ∣ x ) p_{\theta}(z|x) pθ(zx),而往往用KL散度来描述这两个分布是否近似,因此有以下:
K L ( q ϕ ( z ∣ x ) ∣ ∣ p θ ( z ∣ x ) ) = E q ϕ ( z ∣ x ) [ l o g ( q ϕ ( z ∣ x ) / p θ ( z ∣ x ) ) ] = E q ϕ ( z ∣ x ) [ l o g q ϕ ( z ∣ x ) − l o g p θ ( z ∣ x ) ] = E q ϕ ( z ∣ x ) [ l o g q ϕ ( z ∣ x ) − l o g ( p θ ( x ∣ z ) p θ ( z ) / p θ ( x ) ) ] = E q ϕ ( z ∣ x ) [ l o g q ϕ ( z ∣ x ) − l o g ( p θ ( x ∣ z ) − l o g p θ ( z ) + l o g p θ ( x ) ) ] = E q ϕ ( z ∣ x ) [ l o g q ϕ ( z ∣ x ) − l o g p θ ( z ) ] − E q ϕ ( z ∣ x ) [ l o g ( p θ ( x ∣ z ) ] + l o g p θ ( x ) = K L ( q ϕ ( z ∣ x ) ∣ ∣ p θ ( z ) ) − E q ϕ ( z ∣ x ) [ l o g ( p θ ( x ∣ z ) ] + l o g p θ ( x ) KL(q_{\phi}(z|x)||p_{\theta}(z|x))=E_{q_{\phi}(z|x)}[log(q_{\phi}(z|x)/p_{\theta}(z|x))]\\ =E_{q_{\phi}(z|x)}[logq_{\phi}(z|x) - logp_{\theta}(z|x)]=E_{q_{\phi}(z|x)}[logq_{\phi}(z|x) - log(p_{\theta}(x|z)p_{\theta}(z)/p_{\theta}(x))]\\ =E_{q_{\phi}(z|x)}[logq_{\phi}(z|x) - log(p_{\theta}(x|z)-logp_{\theta}(z)+logp_{\theta}(x))]\\ =E_{q_{\phi}(z|x)}[logq_{\phi}(z|x) -logp_{\theta}(z)]-E_{q_{\phi}(z|x)}[ log(p_{\theta}(x|z)]+logp_{\theta}(x)\\ =KL(q_{\phi}(z|x)||p_{\theta}(z))-E_{q_{\phi}(z|x)}[ log(p_{\theta}(x|z)]+logp_{\theta}(x) KL(qϕ(zx)∣∣pθ(zx))=Eqϕ(zx)[log(qϕ(zx)/pθ(zx))]=Eqϕ(zx)[logqϕ(zx)logpθ(zx)]=Eqϕ(zx)[logqϕ(zx)log(pθ(xz)pθ(z)/pθ(x))]=Eqϕ(zx)[logqϕ(zx)log(pθ(xz)logpθ(z)+logpθ(x))]=Eqϕ(zx)[logqϕ(zx)logpθ(z)]Eqϕ(zx)[log(pθ(xz)]+logpθ(x)=KL(qϕ(zx)∣∣pθ(z))Eqϕ(zx)[log(pθ(xz)]+logpθ(x)
记: L ( θ , ϕ ; x ( i ) ) = − K L ( q ϕ ( z ∣ x ( i ) ) ∣ ∣ p θ ( z ) ) + E q ϕ ( z ∣ x ( i ) ) [ l o g ( p θ ( x ( i ) ∣ z ) ] . . . . . . . . . . ( 3 ) L(\theta,\phi;x^{(i)})=-KL(q_{\phi}(z|x^{(i)})||p_{\theta}(z))+E_{q_{\phi}(z|x^{(i)})}[ log(p_{\theta}(x^{(i)}|z)] ..........(3) L(θ,ϕ;x(i))=KL(qϕ(zx(i))∣∣pθ(z))+Eqϕ(zx(i))[log(pθ(x(i)z)]..........(3)
这就是文章中的公式(3),我们把公式(3)带入上面的式子就得到下面论文中的公式(1):
l o g p θ ( x ( i ) ) = K L ( q ϕ ( z ∣ x ( i ) ) ∣ ∣ p θ ( z ∣ x ( i ) ) ) + L ( θ , ϕ ; x ( i ) ) . . . . . . . . . . . . ( 1 ) logp_{\theta}(x^{(i)})=KL(q_{\phi}(z|x^{(i)})||p_{\theta}(z|x^{(i)}))+L(\theta,\phi;x^{(i)})............(1) logpθ(x(i))=KL(qϕ(zx(i))∣∣pθ(zx(i)))+L(θ,ϕ;x(i))............(1)
由于KL散度的非负性,最大化释然等价于最大化 L ( θ , ϕ ; x ( i ) ) L(\theta,\phi;x^{(i)}) L(θ,ϕ;x(i)),因此我们VAE的损失函数也就出来了,就是围绕公式(3)。
我们假定先验分布 p θ ( z ) = N ( 0 , 1 ) p_{\theta}(z)= N(0,1) pθ(z)=N(0,1)为标准正太分布, q ϕ ( z ∣ x ( i ) ) q_{\phi}(z|x^{(i)}) qϕ(zx(i))近似为 N ( z ; μ ( i ) , σ ( i ) I ) N(z;\mu^{(i)},\sigma^{(i)}I) N(z;μ(i),σ(i)I),左边散度那一项变为以下,具体推导可看原文或者网上的一些,这里直接给结果,其中j表示多维正太分布的维度,也就是编码输出的 μ , σ \mu,\sigma μ,σ维度
K L ( q ϕ ( z ∣ x ( i ) ) ∣ ∣ p θ ( z ) ) ≈ 1 / 2 ∑ j = 1 J ( 1 + l o g ( σ j ( i ) ) 2 − ( μ j ( i ) ) 2 + ( σ j ( i ) ) 2 ) KL(q_{\phi}(z|x^{(i)})||p_{\theta}(z))\approx 1/2\sum_{j=1}^{J}(1+log(\sigma_j^{(i)})^2-(\mu_j^{(i)})^2+(\sigma_j^{(i)})^2) KL(qϕ(zx(i))∣∣pθ(z))1/2j=1J(1+log(σj(i))2(μj(i))2+(σj(i))2)
这以上公式代入公式(3)得到,论文中的公式(10),注意和论文公式(10)有差别,后面一项没有做变换,后面细说,其中J表示z的维度,
L ( θ , ϕ ; x ( i ) ) = 1 / 2 ∑ j = 1 J ( 1 + l o g ( σ j ( i ) ) 2 − ( μ j ( i ) ) 2 + ( σ j ( i ) ) 2 ) + E q ϕ ( z ∣ x ( i ) ) [ l o g ( p θ ( x ( i ) ∣ z ) ] . . . . . . . . . . ( 10 ) L(\theta,\phi;x^{(i)})=1/2\sum_{j=1}^{J}(1+log(\sigma_j^{(i)})^2-(\mu_j^{(i)})^2+(\sigma_j^{(i)})^2)+E_{q_{\phi}(z|x^{(i)})}[ log(p_{\theta}(x^{(i)}|z)] ..........(10) L(θ,ϕ;x(i))=1/2j=1J(1+log(σj(i))2(μj(i))2+(σj(i))2)+Eqϕ(zx(i))[log(pθ(x(i)z)]..........(10)
从概率角度来讲,假设 p θ ( x ( i ) ∣ z ) p_{\theta}(x^{(i)}|z) pθ(x(i)z)也是满足正太分布,方差不变,那么什么时候这个概率最大呢,毫无疑问 x = μ x=\mu x=μ的时候最大,也就是生成模型最终生成的结果是 μ \mu μ并且等于输入X的时候概率最大,那我们就可以用mse损失来代替后一项。论文中给了两种假设,其中之一就是正太分布,另一种是伯努利分布,两种推导请参考https://kexue.fm/archives/5253这里给出了详细的推导。而我们只需用mse会更加简单一些。优化最大,前面加上-号变成优化最小。
至此我们可以把公式(10)改写成可求解的loss,如下,我命名为(*)式:
− L ( θ , ϕ ; x ( i ) ) = − 1 / 2 ∑ j = 1 J ( 1 + l o g ( σ j ( i ) ) 2 − ( μ j ( i ) ) 2 + ( σ j ( i ) ) 2 ) + ∣ ∣ x ( i ) − f ( z ( i ) ) ∣ ∣ 2 2 . . . . . . . . . . ( ∗ ) -L(\theta,\phi;x^{(i)})=-1/2\sum_{j=1}^{J}(1+log(\sigma_j^{(i)})^2-(\mu_j^{(i)})^2+(\sigma_j^{(i)})^2)+||x^{(i)}-f(z^{(i)})||_2^2 ..........(*) L(θ,ϕ;x(i))=1/2j=1J(1+log(σj(i))2(μj(i))2+(σj(i))2)+∣∣x(i)f(z(i))22..........()

到这里我们基本证明了我们最开始loss的猜想,并且给出了可以求解的loss。
还有一个问题由于 z ( i ) z^{(i)} z(i)是在 N ( μ ( i ) , σ ( i ) I ) N(\mu^{(i)},\sigma^{(i)}I) N(μ(i),σ(i)I)随机采样的,我们在代码中反向传播就无法计算了,就没法用torch中反向传播求梯度,导致编码器不可学习了,怎么办呢?论文中巧妙的用了等价的方法,大家都叫重参数技巧 (reparameterization trick)简称trick:
z ( i ) = μ ( i ) + σ ( i ) ⊙ ϵ ( i ) z^{(i)}=\mu^{(i)}+\sigma^{(i)}\odot \epsilon^{(i)} z(i)=μ(i)+σ(i)ϵ(i)

这样我们最终的loss,可以求解的,也可以做反向传播的loss也就出来了,就可以用到代码里面了,过程很复杂,结果非常完美。
− L ( θ , ϕ ; x ( i ) ) = − 1 / 2 ∑ j = 1 J ( 1 + l o g ( σ j ( i ) ) 2 − ( μ j ( i ) ) 2 + ( σ j ( i ) ) 2 ) + ∣ ∣ x ( i ) − f ( μ ( i ) + σ ( i ) ⊙ ϵ ( i ) ) ∣ ∣ 2 2 . . . . . . . . . . ( ∗ ∗ ) -L(\theta,\phi;x^{(i)})=-1/2\sum_{j=1}^{J}(1+log(\sigma_j^{(i)})^2-(\mu_j^{(i)})^2+(\sigma_j^{(i)})^2)+||x^{(i)}-f(\mu^{(i)}+\sigma^{(i)}\odot \epsilon^{(i)})||_2^2 ..........(**) L(θ,ϕ;x(i))=1/2j=1J(1+log(σj(i))2(μj(i))2+(σj(i))2)+∣∣x(i)f(μ(i)+σ(i)ϵ(i))22..........()

备注

公式都是手打的,有错误欢迎指出。
说一个概念,VAE中的V为什么叫变分,首先我们要知道公式(3)有个专业术语叫变分下界,又叫Evidence Lower Bound,简称ELBO。而这个公式由来是通过引入 q ϕ ( z ∣ x ) q_{\phi}(z|x) qϕ(zx),然后通过变分推段的手段,也就是我们利用散度来推导出公式(3)和(1)的过程。才得到ELBO把对释然函数的优化转换成优化ELBO,也就是我们现在这个loss,所以叫变分自编码器。

  • 27
    点赞
  • 10
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

idealmu

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值