用变分推断统一理解生成模型(VAE、GAN、AAE、ALI)隐变量的艺术

前言:我小学开始就喜欢纯数学,后来也喜欢上物理,还学习过一段时间的理论物理,直到本科毕业时,我才慢慢进入机器学习领域。所以,哪怕在机器学习领域中,我的研究习惯还保留着数学和物理的风格:企图从最少的原理出发,理解、推导尽可能多的东西。这篇文章是我这个理念的结果之一,试图以变分推断作为出发点,来统一地理解深度学习中的各种模型,尤其是各种让人眼花缭乱的GAN。本文已经挂到arxiv上,需要读英文原稿的可以移步到《Variational Inference: A Unified Framework of Generative Models and Some Revelations》

下面是文章的介绍。其实,中文版的信息可能还比英文版要稍微丰富一些,原谅我这蹩脚的英语...

摘要:本文从一种新的视角阐述了变分推断,并证明了EM算法、VAE、GAN、AAE、ALI(BiGAN)都可以作为变分推断的某个特例。其中,论文也表明了标准的GAN的优化目标是不完备的,这可以解释为什么GAN的训练需要谨慎地选择各个超参数。最后,文中给出了一个可以改善这种不完备性的正则项,实验表明该正则项能增强GAN训练的稳定性。

近年来,深度生成模型,尤其是GAN,取得了巨大的成功。现在我们已经可以找到数十个乃至上百个GAN的变种。然而,其中的大部分都是凭着经验改进的,鲜有比较完备的理论指导。

本文的目标是通过变分推断来给这些生成模型建立一个统一的框架。首先,本文先介绍了变分推断的一个新形式,这个新形式其实在博客以前的文章中就已经介绍过,它可以让我们在几行字之内导出变分自编码器(VAE)和EM算法。然后,利用这个新形式,我们能直接导出GAN,并且发现标准GAN的loss实则是不完备的,缺少了一个正则项。如果没有这个正则项,我们就需要谨慎地调整超参数,才能使得模型收敛。

实际上,本文这个工作的初衷,就是要将GAN纳入到变分推断的框架下。目前看来,最初的意图已经达到了,结果让人欣慰。新导出的正则项实际上是一个副产品,并且幸运的是,在我们的实验中这个副产品生效了。

变分推断新解 #

假设xx为显变量,zz为隐变量,p̃(x)p~(x)xx的证据分布,并且有

qθ(x)=∫qθ(x,z)dz(1)(1)qθ(x)=∫qθ(x,z)dz


我们希望qθ(x)qθ(x)能逼近p̃(x)p~(x),所以一般情况下我们会去最大化似然函数

θ=argmaxθ∫p̃(x)logq(x)dx(2)(2)θ=arg⁡maxθ∫p~(x)log⁡q(x)dx


这也等价于最小化KL散度KL(p̃(x)‖q(x))KL(p~(x)‖q(x))

KL(p̃(x)‖q(x))=∫p̃(x)logp̃(x)q(x)dx(3)(3)KL(p~(x)‖q(x))=∫p~(x)log⁡p~(x)q(x)dx


但是由于积分可能难以计算,因此大多数情况下都难以直接优化。

 

变分推断中,首先引入联合分布p(x,z)p(x,z)使得p̃(x)=∫p(x,z)dzp~(x)=∫p(x,z)dz,而变分推断的本质,就是将边际分布的KL散度KL(p̃(x)‖q(x))KL(p~(x)‖q(x))改为联合分布的KL散度KL(p(x,z)‖q(x,z))KL(p(x,z)‖q(x,z))KL(q(x,z)‖p(x,z))KL(q(x,z)‖p(x,z)),而

KL(p(x,z)‖q(x,z))=KL(p̃(x)‖q(x))+∫p̃(x)KL(p(z|x)‖q(z|x))dx≥KL(p̃(x)‖q(x))(4)(4)KL(p(x,z)‖q(x,z))=KL(p~(x)‖q(x))+∫p~(x)KL(p(z|x)‖q(z|x))dx≥KL(p~(x)‖q(x))


意味着联合分布的KL散度是一个更强的条件(上界)。所以一旦优化成功,那么我们就得到q(x,z)→p(x,z)q(x,z)→p(x,z),从而∫q(x,z)dz→∫p(x,z)dz=p̃(x)∫q(x,z)dz→∫p(x,z)dz=p~(x),即∫q(x,z)dz∫q(x,z)dz成为了真实分布p̃(x)p~(x)的一个近似。

 

当然,我们本身不是为了加强条件而加强,而是因为在很多情况下,KL(p(x,z)‖q(x,z))KL(p(x,z)‖q(x,z))KL(q(x,z)‖p(x,z))KL(q(x,z)‖p(x,z))往往比KL(p̃(x)‖q(x))KL(p~(x)‖q(x))更加容易计算。所以变分推断是提供了一个可计算的方案。

VAE和EM算法 #

由上述关于变分推断的新理解,我们可以在几句话内导出两个基本结果:变分自编码器和EM算法。这部分内容,实际上在《从最大似然到EM算法:一致的理解方式》《变分自编码器(二):从贝叶斯观点出发》已经详细介绍过了。这里用简单几句话重提一下。

VAE #

在VAE中,我们设q(x,z)=q(x|z)q(z),p(x,z)=p̃(x)p(z|x)q(x,z)=q(x|z)q(z),p(x,z)=p~(x)p(z|x),其中q(x|z),p(z|x)q(x|z),p(z|x)带有未知参数的高斯分布而q(z)q(z)是标准高斯分布。最小化的目标是

KL(p(x,z)‖q(x,z))=∬p̃(x)p(z|x)logp̃(x)p(z|x)q(x|z)q(z)dxdz(5)(5)KL(p(x,z)‖q(x,z))=∬p~(x)p(z|x)log⁡p~(x)p(z|x)q(x|z)q(z)dxdz


其中logp̃(x)log⁡p~(x)没有包含优化目标,可以视为常数,而对p̃(x)p~(x)的积分则转化为对样本的采样,从而

𝔼x∼p̃(x)[−∫p(z|x)logq(x|z)dz+KL(p(z|x)‖q(z))](6)(6)Ex∼p~(x)[−∫p(z|x)log⁡q(x|z)dz+KL(p(z|x)‖q(z))]


因为q(x|z),p(z|x)q(x|z),p(z|x)为带有神经网络的高斯分布,这时候KL(p(z|x)‖q(z))KL(p(z|x)‖q(z))可以显式地算出,而通过重参数技巧来采样一个点完成积分∫p(z|x)logq(x|z)dz∫p(z|x)log⁡q(x|z)dz的估算,可以得到VAE最终要最小化的loss:

𝔼x∼p̃(x)[−logq(x|z)+KL(p(z|x)‖q(z))](7)(7)Ex∼p~(x)[−log⁡q(x|z)+KL(p(z|x)‖q(z))]

EM算法 #

在VAE中我们对后验分布做了约束,仅假设它是高斯分布,所以我们优化的是高斯分布的参数。如果不作此假设,那么直接优化原始目标(5)(5),在某些情况下也是可操作的,但这时候只能采用交替优化的方式:先固定p(z|x)p(z|x),优化q(x|z)q(x|z),那么就有

q(x|z)=argmaxq(x|z)𝔼x∼p̃(x)[∫p(z|x)logq(x,z)dz](8)(8)q(x|z)=arg⁡maxq(x|z)Ex∼p~(x)[∫p(z|x)log⁡q(x,z)dz]


完成这一步后,我们固定q(x,z)q(x,z),优化p(z|x)p(z|x),先将q(x|z)q(z)q(x|z)q(z)写成q(z|x)q(x)q(z|x)q(x)的形式:

q(x)=∫q(x|z)q(z)dz,q(z|x)=q(x|z)q(z)q(x)(9)(9)q(x)=∫q(x|z)q(z)dz,q(z|x)=q(x|z)q(z)q(x)


那么有

p(z|x)===argminp(z|x)𝔼x∼p̃(x)[∫p(z|x)logp(z|x)q(z|x)q(x)dz]argminp(z|x)𝔼x∼p̃(x)[KL(p(z|x)‖q(z|x))−logq(x)]argminp(z|x)𝔼x∼p̃(x)[KL(p(z|x)‖q(z|x))](10)(10)p(z|x)=arg⁡minp(z|x)Ex∼p~(x)[∫p(z|x)log⁡p(z|x)q(z|x)q(x)dz]=arg⁡minp(z|x)Ex∼p~(x)[KL(p(z|x)‖q(z|x))−log⁡q(x)]=arg⁡minp(z|x)Ex∼p~(x)[KL(p(z|x)‖q(z|x))]


由于现在对p(z|x)p(z|x)没有约束,因此可以直接让p(z|x)=q(z|x)p(z|x)=q(z|x)使得loss等于0。也就是说,p(z|x)p(z|x)有理论最优解:

p(z|x)=q(x|z)q(z)∫q(x|z)q(z)dz(11)(11)p(z|x)=q(x|z)q(z)∫q(x|z)q(z)dz


(8),(11)(8),(11)的交替执行,构成了EM算法的求解步骤。这样,我们从变分推断框架中快速得到了EM算法。

 

变分推断下的GAN #

在这部分内容中,我们介绍了一般化的将GAN纳入到变分推断中的方法,这将引导我们得到GAN的新理解,以及一个有效的正则项。

一般框架 #

同VAE一样,GAN也希望能训练一个生成模型q(x|z)q(x|z),来将q(z)=N(z;0,I)q(z)=N(z;0,I)映射为数据集分布p̃(x)p~(x),不同于VAE中将q(x|z)q(x|z)选择为高斯分布,GAN的选择是

q(x|z)=δ(x−G(z)),q(x)=∫q(x|z)q(z)dz(12)(12)q(x|z)=δ(x−G(z)),q(x)=∫q(x|z)q(z)dz


其中δ(x)δ(x)是狄拉克δδ函数,G(z)G(z)即为生成器的神经网络。

 

一般我们会认为zz是一个隐变量,但由于δδ函数实际上意味着单点分布,因此可以认为xxzz的关系已经是一一对应的,所以zzxx的关系已经“不够随机”,在GAN中我们认为它不是隐变量(意味着我们不需要考虑后验分布p(z|x)p(z|x))。

事实上,在GAN中仅仅引入了一个二元的隐变量yy来构成联合分布

q(x,y)={p̃(x)p1,y=1q(x)p0,y=0(13)(13)q(x,y)={p~(x)p1,y=1q(x)p0,y=0


这里p1=1−p0p1=1−p0描述了一个二元概率分布,我们直接取p1=p0=1/2p1=p0=1/2。另一方面,我们设p(x,y)=p(y|x)p̃(x)p(x,y)=p(y|x)p~(x)p(y|x)p(y|x)是一个条件伯努利分布。而优化目标是另一方向的KL(q(x,y)‖p(x,y))KL(q(x,y)‖p(x,y))

KL(q(x,y)‖p(x,y))=∼∫p̃(x)p1logp̃(x)p1p(1|x)p̃(x)dx+∫q(x)p0logq(x)p0p(0|x)p̃(x)dx∫p̃(x)log1p(1|x)dx+∫q(x)logq(x)p(0|x)p̃(x)dx(14)(14)KL(q(x,y)‖p(x,y))=∫p~(x)p1log⁡p~(x)p1p(1|x)p~(x)dx+∫q(x)p0log⁡q(x)p0p(0|x)p~(x)dx∼∫p~(x)log⁡1p(1|x)dx+∫q(x)log⁡q(x)p(0|x)p~(x)dx


一旦成功优化,那么就有q(x,y)→p(x,y)q(x,y)→p(x,y),那么

p1p̃(x)+p0q(x)=∑yq(x,y)→∑yp(x,y)=p̃(x)(15)(15)p1p~(x)+p0q(x)=∑yq(x,y)→∑yp(x,y)=p~(x)


从而q(x)→p̃(x)q(x)→p~(x),完成了生成模型的构建。

 

现在我们优化对象有p(y|x)p(y|x)G(x)G(x),记p(1|x)=D(x)p(1|x)=D(x),这就是判别器。类似EM算法,我们进行交替优化:先固定G(z)G(z),这也意味着q(x)q(x)固定了,然后优化p(y|x)p(y|x),这时候略去常量,得到优化目标为:

D=argminD−𝔼x∼p̃(x)[logD(x)]−𝔼x∼q(x)[log(1−D(x))](16)(16)D=arg⁡minD−Ex∼p~(x)[log⁡D(x)]−Ex∼q(x)[log⁡(1−D(x))]


然后固定D(x)D(x)来优化G(x)G(x),这时候相关的loss为:

G=argminG∫q(x)logq(x)(1−D(x))p̃(x)dx(17)(17)G=arg⁡minG⁡∫q(x)log⁡q(x)(1−D(x))p~(x)dx


这里包含了我们不知道的p̃(x)p~(x),但是假如D(x)D(x)模型具有足够的拟合能力,那么跟(11)(11)式同理,D(x)D(x)的最优解应该是

D(x)=p̃(x)p̃(x)+qo(x)(18)(18)D(x)=p~(x)p~(x)+qo(x)


这里的qo(x)qo(x)是前一阶段的q(x)q(x)。从中解出p̃(x)p~(x),代入(17)(17)得到

∫q(x)logq(x)D(x)qo(x)dx==−𝔼x∼q(x)logD(x)+KL(q(x)‖qo(x))−𝔼z∼q(z)logD(G(z))+KL(q(x)‖qo(x))(19)(19)∫q(x)log⁡q(x)D(x)qo(x)dx=−Ex∼q(x)log⁡D(x)+KL(q(x)‖qo(x))=−Ez∼q(z)log⁡D(G(z))+KL(q(x)‖qo(x))

基本分析 #

可以看到,第一项就是标准的GAN生成器所采用的loss之一。

−𝔼z∼q(z)logD(G(z))(20)(20)−Ez∼q(z)log⁡D(G(z))


多出来的第二项,描述了新分布与旧分布之间的距离。这两项loss是对抗的,因为KL(q(x)‖qo(x))KL(q(x)‖qo(x))希望新旧分布尽量一致,但是如果判别器充分优化的话,对于旧分布qo(x)qo(x)中的样本,D(x)D(x)都很小(几乎都被识别为负样本),所以−logD(x)−log⁡D(x)会相当大,反之亦然。这样一来,整个loss一起优化的话,模型既要“传承”旧分布qo(x)qo(x),同时要在往新方向p(1|y)p(1|y)探索,在新旧之间插值。

 

我们知道,目前标准的GAN的生成器loss都不包含KL(q(x)‖qo(x))KL(q(x)‖qo(x)),这事实上造成了loss的不完备。假设有一个优化算法总能找到G(z)G(z)的理论最优解、并且G(z)G(z)具有无限的拟合能力,那么G(z)G(z)只需要生成唯一一个使得D(x)D(x)最大的样本(不管输入的zz是什么),这就是模型坍缩。这样说的话,理论上它一定会发生。

那么,KL(q(x)‖qo(x))KL(q(x)‖qo(x))给我们的启发是什么呢?我们设

qo(x)=qθ−Δθ(x),q(x)=qθ(x)(21)(21)qo(x)=qθ−Δθ(x),q(x)=qθ(x)


也就是说,假设当前模型的参数改变量为ΔθΔθ,那么展开到二阶得到

KL(q(x)‖qo(x))≈∫(Δθ⋅∇θqθ(x))22qθ(x)dx≈(Δθ⋅c)2(22)(22)KL(q(x)‖qo(x))≈∫(Δθ⋅∇θqθ(x))22qθ(x)dx≈(Δθ⋅c)2

我们已经指出一个完备的GAN生成器的损失函数应该要包含KL(q(x)‖qo(x))KL(q(x)‖qo(x)),如果不包含的话,那么就要通过各种间接手段达到这个效果,上述近似表明额外的损失约为(Δθ⋅c)2(Δθ⋅c)2,这就要求我们不能使得它过大,也就是不能使得ΔθΔθ过大(在每个阶段cc可以近似认为是一个常数)。而我们用的是基于梯度下降的优化算法,所以ΔθΔθ正比于梯度,因此标准GAN训练时的很多trick,比如梯度裁剪、用adam优化器、用BN,都可以解释得通了,它们都是为了稳定梯度,使得θθ不至于过大,同时,G(z)G(z)的迭代次数也不能过多,因为过多同样会导致ΔθΔθ过大。

还有,这部分的分析只适用于生成器,而判别器本身并不受约束,因此判别器可以训练到最优。

正则项 #

现在我们从中算出一些真正有用的内容,直接对KL(q(x)‖qo(x))KL(q(x)‖qo(x))进行估算,以得到一个可以在实际训练中使用的正则项。直接计算是难以进行的,但我们可以用KL(q(x,z)‖q̃(x,z))KL(q(x,z)‖q~(x,z))去估算它:

KL(q(x,z)‖q̃(x,z))===∬q(x|z)q(z)logq(x|z)q(z)q̃(x|z)q(z)dxdz∬δ(x−G(z))q(z)logδ(x−G(z))δ(x−Go(z))dxdz∫q(z)logδ(0)δ(G(z)−Go(z))dz(23)(23)KL(q(x,z)‖q~(x,z))=∬q(x|z)q(z)log⁡q(x|z)q(z)q~(x|z)q(z)dxdz=∬δ(x−G(z))q(z)log⁡δ(x−G(z))δ(x−Go(z))dxdz=∫q(z)log⁡δ(0)δ(G(z)−Go(z))dz


因为有极限

δ(x)=limσ→01(2πσ2)d/2exp(−x22σ2)(24)(24)δ(x)=limσ→01(2πσ2)d/2exp⁡(−x22σ2)


所以可以将δ(x)δ(x)看成是小方差的高斯分布,代入算得也就是我们有

KL(q(x)‖qo(x))∼λ∫q(z)‖G(z)−Go(z)‖2dz(25)(25)KL(q(x)‖qo(x))∼λ∫q(z)‖G(z)−Go(z)‖2dz


所以完整生成器的loss可以选为

𝔼z∼q(z)[−logD(G(z))+λ‖G(z)−Go(z)‖2](26)(26)Ez∼q(z)[−log⁡D(G(z))+λ‖G(z)−Go(z)‖2]


也就是说,可以用新旧生成样本的距离作为正则项,正则项保证模型不会过于偏离旧分布。

 

下面的两个在人脸数据CelebA上的实验表明这个正则项是生效的。实验代码修改自这里,目前放在我的github上。

实验一:普通的DCGAN网络,每次迭代生成器和判别器各训练一个batch。

 

不带正则项,在25个epoch之后模型开始坍缩

 

带有正则项,模型能一直稳定训练

 

实验二:普通的DCGAN网络,但去掉BN,每次迭代生成器和判别器各训练五个batch。

 

不带正则项,模型收敛速度比较慢

 

带有正则项,模型更快“步入正轨”

 

GAN相关模型 #

对抗自编码器(Adversarial Autoencoders,AAE)和对抗推断学习(Adversarially Learned Inference,ALI)这两个模型是GAN的变种之一,也可以被纳入到变分推断中。当然,有了前述准备后,这仅仅就像两道作业题罢了。

有意思的是,在ALI之中,我们有一些反直觉的结果。

GAN视角下的AAE #

事实上,只需要在GAN的论述中,将x,zx,z的位置交换,就得到了AAE的框架。

具体来说,AAE希望能训练一个编码模型p(z|x)p(z|x),来将真实分布q̃(x)q~(x)映射为标准高斯分布q(z)=N(z;0,I)q(z)=N(z;0,I),而

p(z|x)=δ(z−E(x)),p(z)=∫p(z|x)q̃(x)dx(27)(27)p(z|x)=δ(z−E(x)),p(z)=∫p(z|x)q~(x)dx


其中E(x)E(x)即为编码器的神经网络。

 

同GAN一样,AAE引入了一个二元的隐变量yy,并有

p(z,y)={p(z)p1,y=1q(z)p0,y=0(28)(28)p(z,y)={p(z)p1,y=1q(z)p0,y=0


同样直接取p1=p0=1/2p1=p0=1/2。另一方面,我们设q(z,y)=q(y|z)q(z)q(z,y)=q(y|z)q(z),这里的后验分布p(y|z)p(y|z)是一个输入为zz的二元分布,然后去优化KL(p(z,y)‖q(z,y))KL(p(z,y)‖q(z,y))

KL(p(z,y)‖q(z,y))=∼∫p(z)p1logp(z)p1q(1|z)q(z)dz+∫q(z)p0logq(z)p0q(0|z)q(z)dz∫p(z)logp(z)q(1|z)q(z)dz+∫q(z)log1q(0|z)dz(29)(29)KL(p(z,y)‖q(z,y))=∫p(z)p1log⁡p(z)p1q(1|z)q(z)dz+∫q(z)p0log⁡q(z)p0q(0|z)q(z)dz∼∫p(z)log⁡p(z)q(1|z)q(z)dz+∫q(z)log⁡1q(0|z)dz

现在我们优化对象有q(y|z)q(y|z)E(x)E(x),记q(0|z)=D(z)q(0|z)=D(z),依然交替优化:先固定E(x)E(x),这也意味着p(z)p(z)固定了,然后优化q(y|z)q(y|z),这时候略去常量,得到优化目标为:

D=argminD=argminD−𝔼z∼p(z)[log(1−D(z))]−𝔼z∼q(z)[logD(z)]−𝔼z∼p̃(x)[log(1−D(E(x)))]−𝔼z∼q(z)[logD(z)](30)(30)D=arg⁡minD−Ez∼p(z)[log⁡(1−D(z))]−Ez∼q(z)[log⁡D(z)]=arg⁡minD−Ez∼p~(x)[log⁡(1−D(E(x)))]−Ez∼q(z)[log⁡D(z)]


然后固定D(z)D(z)来优化E(x)E(x),这时候相关的loss为:

E=argminE∫p(z)logp(z)(1−D(z))q(z)dz(31)(31)E=arg⁡minE⁡∫p(z)log⁡p(z)(1−D(z))q(z)dz


利用D(z)D(z)的理论最优解D(z)=q(z)/[po(z)+q(z)]D(z)=q(z)/[po(z)+q(z)],代入loss得到

𝔼x∼p̃(x)[−logD(E(x))]+KL(p(z)‖po(z))(32)(32)Ex∼p~(x)[−log⁡D(E(x))]+KL(p(z)‖po(z))


一方面,同标准GAN一样,谨慎地训练,我们可以去掉第二项,得到

𝔼x∼p̃(x)[−logD(E(x))](33)(33)Ex∼p~(x)[−log⁡D(E(x))]


另外一方面,我们可以得到编码器后再训练一个解码器G(z)G(z),但是如果所假设的E(x),G(z)E(x),G(z)的拟合能力是充分的,重构误差可以足够小,那么将G(z)G(z)加入到上述loss中并不会干扰GAN的训练,因此可以联合训练:

G,E=argminG,E𝔼x∼p̃(x)[−logD(E(x))+λ‖x−G(E(x))‖2](34)(34)G,E=arg⁡minG,E⁡Ex∼p~(x)[−log⁡D(E(x))+λ‖x−G(E(x))‖2]

反直觉的ALI版本 #

ALI像是GAN和AAE的融合,另一个几乎一样的工作是Bidirectional GAN (BiGAN)。相比于GAN,它将zz也作为隐变量纳入到变分推断中。具体来说,在ALI中有

q(x,z,y)={p(z|x)p̃(x)p1,y=1q(x|z)q(z)p0,y=0(35)(35)q(x,z,y)={p(z|x)p~(x)p1,y=1q(x|z)q(z)p0,y=0


以及p(x,z,y)=p(y|x,z)p(z|x)p̃(x)p(x,z,y)=p(y|x,z)p(z|x)p~(x),然后去优化KL(q(x,z,y)‖p(x,z,y))KL(q(x,z,y)‖p(x,z,y))

+∬p(z|x)p̃(x)p1logp(z|x)p̃(x)p1p(1|x,z)p(z|x)p̃(x)dxdz∬q(x|z)q(z)p0logq(x|z)q(z)p0p(0|x,z)p(z|x)p̃(x)dxdz(36)(36)∬p(z|x)p~(x)p1log⁡p(z|x)p~(x)p1p(1|x,z)p(z|x)p~(x)dxdz+∬q(x|z)q(z)p0log⁡q(x|z)q(z)p0p(0|x,z)p(z|x)p~(x)dxdz


等价于最小化

∬p(z|x)p̃(x)log1p(1|x,z)dxdz+∬q(x|z)q(z)logq(x|z)q(z)p(0|x,z)p(z|x)p̃(x)dxdz(37)(37)∬p(z|x)p~(x)log⁡1p(1|x,z)dxdz+∬q(x|z)q(z)log⁡q(x|z)q(z)p(0|x,z)p(z|x)p~(x)dxdz


现在优化的对象有p(y|x,z),p(z|x),q(x|z)p(y|x,z),p(z|x),q(x|z),记p(1|x,z)=D(x,z)p(1|x,z)=D(x,z),而p(z|x)p(z|x)是一个带有编码器EE的高斯分布或狄拉克分布,q(x|z)q(x|z)是一个带有生成器GG的高斯分布或狄拉克分布。依然交替优化:先固定E,GE,G,那么与DD相关的loss为

D=argminD−𝔼x∼p̃(x),z∼p(z|x)logD(x,z)−𝔼z∼q(z),x∼q(x|z)log(1−D(x,z))(38)(38)D=arg⁡minD−Ex∼p~(x),z∼p(z|x)log⁡D(x,z)−Ez∼q(z),x∼q(x|z)log⁡(1−D(x,z))


跟VAE一样,对p(z|x)p(z|x)q(x|z)q(x|z)的期望可以通过“重参数”技巧完成。接着固定DD来优化G,EG,E,因为这时候有EE又有GG,整个loss没得化简,还是(37)(37)那样。但利用DD的最优解

D(x,z)=po(z|x)p̃(x)po(z|x)p̃(x)+qo(x|z)q(z)(39)(39)D(x,z)=po(z|x)p~(x)po(z|x)p~(x)+qo(x|z)q(z)


可以转化为

−∬p(z|x)p̃(x)logD(x,z)dxdz−∬q(x|z)q(z)logD(x,z)dxdz+∫q(z)KL(q(x|z)‖qo(x|z))dz+∬q(x|z)q(z)logpo(z|x)p(z|x)dxdz(40)(40)−∬p(z|x)p~(x)log⁡D(x,z)dxdz−∬q(x|z)q(z)log⁡D(x,z)dxdz+∫q(z)KL(q(x|z)‖qo(x|z))dz+∬q(x|z)q(z)log⁡po(z|x)p(z|x)dxdz


由于q(x|z),p(x|z)q(x|z),p(x|z)都是高斯分布,事实上后两项我们可以具体地算出来(配合重参数技巧),但同标准GAN一样,谨慎地训练,我们可以简单地去掉后面两项,得到

−∬p(z|x)p̃(x)logD(x,z)dxdz−∬q(x|z)q(z)logD(x,z)dxdz(41)(41)−∬p(z|x)p~(x)log⁡D(x,z)dxdz−∬q(x|z)q(z)log⁡D(x,z)dxdz


这就是我们导出的ALI的生成器和编码器的loss,它跟标准的ALI结果有所不同。标准的ALI(包括普通的GAN)将其视为一个极大极小问题,所以生成器和编码器的loss为

∬p(z|x)p̃(x)logD(x,z)dxdz+∬q(x|z)q(z)log(1−D(x,z))dxdz(42)(42)∬p(z|x)p~(x)log⁡D(x,z)dxdz+∬q(x|z)q(z)log⁡(1−D(x,z))dxdz


−∬p(z|x)p̃(x)log(1−D(x,z))dxdz−∬q(x|z)q(z)logD(x,z)dxdz(43)(43)−∬p(z|x)p~(x)log⁡(1−D(x,z))dxdz−∬q(x|z)q(z)log⁡D(x,z)dxdz


它们都不等价于(41)(41)。针对这个差异,事实上笔者也做了实验,结果表明这里的ALI有着和标准的ALI同样的表现,甚至可能稍好一些(可能是我的自我良好的错觉,所以就没有放图了)。这说明,将对抗网络视为一个极大极小问题仅仅是一个直觉行为,并非总应该如此。

 

结论综述 #

本文的结果表明了变分推断确实是一个推导和解释生成模型的统一框架,包括VAE和GAN。通过变分推断的新诠释,我们介绍了变分推断是如何达到这个目的的。

当然,本文不是第一篇提出用变分推断研究GAN这个想法的文章。在《On Unifying Deep Generative Models》一文中,其作者也试图用变分推断统一VAE和GAN,也得到了一些启发性的结果。但笔者觉得那不够清晰。事实上,我并没有完全读懂这篇文章,我不大确定,这篇文章究竟是将GAN纳入到了变分推断中了,还是将VAE纳入到了GAN中~相对而言,我觉得本文的论述更加清晰、明确一些。

看起来变分推断还有很大的挖掘空间,等待着我们去探索。

转载到请包括本文地址:https://spaces.ac.cn/archives/5716

  • 0
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值