极简笔记 VAE(变分自编码器)

版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/Hibercraft/article/details/80457445

极简笔记 VAE(变分自编码器)

论文原文:Auto-Encoding Variational Bayes

这是一篇极其拗口的文章,但是文章从变分推断一路延伸到自编码器的构造,过程一气呵成,和当下DL领域的灌水之风形成鲜明对比,是难得的佳作。为了能够从理论到实现融会贯通地理解,本篇笔记会更加偏向于思路解读而非原文复述。

VAE是一个生成模型,对于生成模型,我们希望求得的都是原始数据分布p(x)。但是我们有的只是离散的对真实分布的采样{x1,x2,x3,...,xn},这就是我们的数据集。这时候我们常常会先假设分布的类型(e.g. 高斯分布,均匀分布),然后用最大似然(ML)来做,计算参数θ=argmaxθilogp(xi),从而求得pθ(x)。但是p(x)可能是一个很复杂的难以表达的分布,难以选择合适的假设分布来计算最大似然。

于是假设除了可见变量x,还存在着潜变量z,且z满足某个简单的分布p(z)。那么原始数据分布可以按照后验概率公式分解成p(x)=p(x|z)p(z)dz,但是在离散数据情况下边缘化z是非常费时的,于是我们想到了变分推断。

logpθ(xi)有下界L(θ,ϕ;xi), 这个下界可以进一步拆分:

L(θ,ϕ;xi)=DKL(qϕ(z|xi)||pθ(z))+Eqϕ(z|xi)[logpθ(xi|z)]

这里非常重要,为了提升下界,等价于减小第一项的KL散度,增加第二项的期望。那么第一项可以看做是对qϕ(z|xi)的正则项,使其更加接近于先验分布pθ(z);第二项可以看做是负重构误差项,当z满足q分布时,logpθ(xi|z)越来越接近logpθ(xi)(输出越来越接近输入)。因为别忘了L(θ,ϕ;xi)就是logpθ(xi)的下界啊,前面KL散度趋近于0,下界就只有后一项了,又要接近于原来的logpθ(xi),所以可以认为要求logpθ(xi|z)越来越接近logpθ(xi)

提到重构误差有没有想起啥?就是auto-encoder!如果对原始的自编码器中间的潜变量z加上上文的正则化约束,就建立了理论到模型的桥梁!既然这么相似,那么我们确认一个目标,就是要用AE来完成下界L(θ,ϕ;xi)的提升。

其实到目前为止理论和模型的桥梁并没有完全打通,因为AE作为神经网络靠随机梯度下降可以做到最优化某个函数。且AE对于一个确定的输入xi,只会有一个确定的z的分布(即分布参数是确定的)产生,而上面这个下界L(θ,ϕ;xi)里第二项还有对z采样的部分。这里文章提出了一个神奇的trick——重参数化(reparameterization)。

在重参数化之前,回顾一下初衷,即得到p(x)。但p(x)分布可能太复杂,希望有一个简单的潜变量分布p(z),通过迂回的方式学到p(x)。那么我们就假设这个简单的潜变量分布的先验是个标准正态分布zN(0,1)。同时中间的正则化项不是要让qϕ(z|xi)逼近先验pθ(z)嘛,如果不是同一个分布簇的话怎么逼近(文章提到有相关证明,KL散度逼近0必定是同分布簇),所以假设分布qϕ(z|xi)也是高斯分布,但是一开始他的参数ϕ(包含期望和方差)并不是0和1,而是μiσi2。那么下界第一项的KL散度就可以简化成

KL(N(μ,σ2)||N(0,1))

=12πσ2e(xμ)2/2σ2(loge(xμ)2/2σ2/2πσ2ex2/2σ2/2π)dx

=12πσ2e(xμ)2/2σ2log(1σ2exp(12[x2(xμ)2/σ2]))dx

=1212πσ2e(xμ)2/2σ2[logσ2+x2(xμ)2/σ2]dx

=12(logσ2+μ2+σ21)

从这个简化结果可以看出,AE的encoder可以回归μσ2,然后通过计算上式进行约束即可。

在训练过程中,AE的decoder需要对zpθ(z|xi)=N(μi,σi2)进行采样,这里文章对z进行重参数化操作使得z=gϕ(ϵ,x),其中ϵN(0,1), 于是对z采样等价于对μi+ϵσiN(μi,σi2),两者的分布是完全相同的。μi,σi对每个xi是确定的,引入的ϵ在AE的支路上,在BP时不会对encoder部分产生影响。
图片来源见[2]

采样的问题解决了,相应的decoder部分化为

Eqϕ(z|xi)[logpθ(xi|z)]=1Ll=1L(logpθ(xi|zi,l))

where zi,l=gϕ(ϵi,l,xi)and ϵlp(ϵ)

最终VAE的loss,其中j表示向量的各个分量索引,表示逐元素乘法:
L(θ,ϕ;xi)12j=1J(1+log((σij)2)(μij)2(σij)2)+1Ll=1Llogpθ(xi|zi,l)

where zi,l=μi+σiϵl,ϵlN(0,I)

从变分推断理论到VAE网络模型的路径完全打通,完结撒花!

参考阅读:
[1]https://kexue.fm/archives/5253
[2]https://zhuanlan.zhihu.com/p/25401928

展开阅读全文

没有更多推荐了,返回首页