模型学习 - VAE(变分自编码)专题

       理解出错之处望不吝指正。
       今天听师兄给我们讲了VAE,觉得颇有收获,分享一下,希望大家批评指正。

生成模型

       生成模型的目的是从一系列样本 x = { x 1 , x 2 , . . . , x m } x=\{x_1,x_2,...,x_m\} x={x1,x2,...,xm}中学习 x x x的分布 p ( x ) p(x) p(x),我们可以仿照EM算法,通过隐变量 z z z和生成函数 g ( ) g() g()来得到 x ^ = g ( z ) \hat x=g(z) x^=g(z),并尽可能的让 x ^ \hat {x} x^接近 x x x
       上述方法有一个弊端,我们首先依据全概率公式将 p ( x ) p(x) p(x)写成如下形式:
p ( x ) = ∑ p ( x ∣ z ) p ( z ) p(x)=\sum p(x|z)p(z) p(x)=p(xz)p(z)
       我们易获取到 p ( z ) p(z) p(z),但是在 z z z确定的情况下,无法得知 p ( x ∣ z ) p(x|z) p(xz),这意味着我们无法将 “利用 g ( ) g() g()函数生成的 x ^ i \hat x_i x^i 与 “真实的 x i x_i xi 进行对应,如下图所示:
       在这里插入图片描述
       那么怎么解决这个问题呢,接着往下看~

VAE模型

       在VAE模型中,我们解决了上述问题。我们可以从两个角度理解VAE模型,首先是角度1。

理解角度1

       为了解决上述问题,将学习 p ( x ) p(x) p(x)改为学习 p ( z ∣ x ) p(z|x) p(zx)即可!
       我们令 p ( z ∣ x ) ∼ N ( μ , σ 2 ) p(z|x)\sim N(\mu,\sigma^2) p(zx)N(μ,σ2)(ps:为什么这个先验分布是正态分布呢?因为若是其他分布,后面计算KL散度的时候会导致分母为0)。则VAE的编码解码过程如下图所示:
       在这里插入图片描述
       更具体的,我们假设 p ( z ∣ x ) ∼ N ( 0 , 1 ) p(z|x)\sim N(0,1) p(zx)N(0,1),即:我们在编码过程中期望学到 μ i = 0 , σ i = 1 \mu_i=0,\sigma_i=1 μi=0,σi=1,则VAE的训练过程中会产生以下两种情况(类似于对抗):
       1.若 σ i 2 → 1 \sigma_i^2\rightarrow1 σi21,此时加在 x i x_i xi上的噪声大,会导致已有的解码能力效果变差,通过反向传播会使得 σ i 2 → 0 \sigma_i^2\rightarrow0 σi20
       2.若 σ i 2 → 0 \sigma_i^2\rightarrow0 σi20,此时加在 x i x_i xi上的噪声小,会导致已有的解码能力效果变好,通过反向传播会使得 σ i 2 → 1 \sigma_i^2\rightarrow1 σi21
       
       VAE模型的损失函数如下:
L v a e p ( z ∣ x ) = l o s s ( x , x ^ ) + K L [ N ( μ , σ 2 ) ∣ ∣ N ( 0 , 1 ) ] L_{vae}^{p(z|x)}=loss(x,\hat x)+KL[N(\mu,\sigma^2)||N(0,1)] Lvaep(zx)=loss(x,x^)+KL[N(μ,σ2)N(0,1)]
       损失函数中第一部分 l o s s ( x , x ^ ) loss(x,\hat x) loss(x,x^)代表真实数据 x x x与生成数据 x ^ \hat x x^之间的误差,可以使用简单的logistics损失或者MSE损失。
       损失函数中第二部分是一个KL散度,用于衡量编码过程中得到的分布是否接近我们设置的先验分布 N ( 0 , 1 ) N(0,1) N(0,1),下面对这部分进行详细的剖析。

                K L [ N ( μ , σ 2 ) ∣ ∣ N ( 0 , 1 ) ] KL[N(\mu,\sigma^2)||N(0,1)] KL[N(μ,σ2)N(0,1)]

           = ∫ 1 2 π σ e x p { − ( x − μ ) 2 2 σ 2 } ⋅ log ⁡ 1 2 π σ e x p { − ( x − μ ) 2 2 σ 2 } 1 2 π e x p { − x 2 2 } d x =\int \frac{1}{\sqrt{2\pi}\sigma}exp\{\frac{-(x-\mu)^2}{2\sigma^2}\}\centerdot\log \frac{\frac{1}{\sqrt{2\pi}\sigma}exp\{\frac{-(x-\mu)^2}{2\sigma^2}\}}{\frac{1}{\sqrt{2\pi}}exp\{\frac{-x^2}{2}\}}dx =2π σ1exp{2σ2(xμ)2}log2π 1exp{2x2}2π σ1exp{2σ2(xμ)2}dx

           = 1 2 ∫ 1 2 π σ e x p { − ( x − μ ) 2 2 σ 2 } ⋅ [ − log ⁡ σ 2 + x 2 − ( x − μ ) 2 σ 2 ] d x =\frac{1}{2}\int \frac{1}{\sqrt{2\pi}\sigma}exp\{\frac{-(x-\mu)^2}{2\sigma^2}\}\centerdot[-\log \sigma^2+x^2-\frac{(x-\mu)^2}{\sigma^2}]dx =212π σ1exp{2σ2(xμ)2}[logσ2+x2σ2(xμ)2]dx

           = 1 2 ( − log ⁡ σ 2 + μ 2 + σ 2 − 1 ) =\frac{1}{2}(-\log \sigma^2+\mu^2+\sigma^2-1) =21(logσ2+μ2+σ21)

           = 1 2 μ 2 + 1 2 ( σ 2 − log ⁡ σ 2 − 1 ) =\frac{1}{2}\mu^2+\frac{1}{2}(\sigma^2-\log \sigma^2-1) =21μ2+21(σ2logσ21)

理解角度2

       在理解角度1中,我们由于不能计算 p ( x ∣ z ) p(x|z) p(xz),所以将学习 p ( x ) p(x) p(x)改为了学习 p ( z ∣ x ) p(z|x) p(zx),但是我们忽略了 p ( x ) p(x) p(x)不仅可以使用全概率公式分解为 p ( x ) = ∑ p ( x ∣ z ) p ( z ) p(x)=\sum p(x|z)p(z) p(x)=p(xz)p(z),还可将 p ( x ) p(x) p(x)写为 p ( x ) = ∫ p ( x , z ) d z p(x)=\int p(x,z)dz p(x)=p(x,z)dz,在这种情况下,我们假设先验分布为 q ( x , z ) q(x,z) q(x,z),则我们的学习目标变为:令 p ( x , z ) p(x,z) p(x,z)无限趋近 q ( x , z ) q(x,z) q(x,z),如下所示:

                K L [ p ( x , z ) ∣ ∣ q ( x , z ) ] KL[p(x,z)||q(x,z)] KL[p(x,z)q(x,z)]

           = ∫ ∫ p ( x , z ) log ⁡ p ( x , z ) q ( x , z ) d z d x =\int\int p(x,z)\log \frac{p(x,z)}{q(x,z)} dz dx =p(x,z)logq(x,z)p(x,z)dzdx

          将 p ( x , z ) = p ^ ( x ) p ( z ∣ x ) p(x,z)=\hat p(x)p(z|x) p(x,z)=p^(x)p(zx)带入上式,其中 p ^ ( x ) \hat p(x) p^(x)代表利用已有的 x x x值通过估计得到的分布,可得:

           = ∫ p ^ ( x ) [ ∫ p ( z ∣ x ) log ⁡ p ^ ( x ) p ( z ∣ x ) q ( x , z ) d z ] d x =\int \hat p(x)[ \int p(z|x) \log \frac{\hat p(x)p(z|x)}{q(x,z)}dz] dx =p^(x)[p(zx)logq(x,z)p^(x)p(zx)dz]dx

           = E x ∼ p ^ ( x ) [ ∫ p ( z ∣ x ) log ⁡ p ^ ( x ) p ( z ∣ x ) q ( x , z ) d z ] =E_{x\sim\hat p(x)}[\int p(z|x) \log \frac {\hat p(x)p(z|x)}{q(x,z)}dz] =Exp^(x)[p(zx)logq(x,z)p^(x)p(zx)dz]

          将 q ( x , z ) = q ( z ) q ( x ∣ z ) q(x,z)=q(z)q(x|z) q(x,z)=q(z)q(xz) p ( z ∣ x ) log ⁡ p ( z ∣ x ) q ( z ) = K L [ p ( z ∣ x ) ∣ ∣ q ( z ) ] p(z|x)\log \frac {p(z|x)}{q(z)}=KL[p(z|x)||q(z)] p(zx)logq(z)p(zx)=KL[p(zx)q(z)]带入上式,可得:

           = E x ∼ p ^ ( x ) { E z ∼ p ( z ∣ x ) [ − log ⁡ q ( x ∣ z ) ] + K L [ p ( z ∣ x ) ∣ ∣ q ( z ) ] } =E_{x\sim\hat p(x)}\{E_{z\sim p(z|x)}[-\log q(x|z)] + KL[p(z|x)||q(z)]\} =Exp^(x){Ezp(zx)[logq(xz)]+KL[p(zx)q(z)]}


       我们可以发现,这样计算得到的公式中每一项可以和 L v a e p ( z ∣ x ) L_{vae}^{p(z|x)} Lvaep(zx)中的每一项对应:

           − log ⁡ q ( x ∣ z ) ⟷ l o s s ( x , x ^ ) -\log q(x|z)\longleftrightarrow loss(x,\hat x) logq(xz)loss(x,x^)

           K L [ p ( z ∣ x ) ∣ ∣ q ( z ) ] ⟷ K L [ N ( μ , σ 2 ) ∣ ∣ N ( 0 , 1 ) ] KL[p(z|x)||q(z)]\longleftrightarrow KL[N(\mu,\sigma^2)||N(0,1)] KL[p(zx)q(z)]KL[N(μ,σ2)N(0,1)]


       原文中说 q ( x ∣ z ) q(x|z) q(xz)可以为两种分布:1.伯努利分布(B);2.正态分布(N)
       1.当 q ( x ∣ z ) ∼ B q(x|z)\sim B q(xz)B时,容易看出这是交叉熵损失:
− log ⁡ q ( x ∣ z ) = ∑ { x k log ⁡ p k ( z ) + ( 1 − x k ) log ⁡ [ 1 − p k ( z ) ] } -\log q(x|z) = \sum \{x_k\log p_k(z)+(1-x_k)\log [1-p_k(z)]\} logq(xz)={xklogpk(z)+(1xk)log[1pk(z)]}
       2.当 q ( x ∣ z ) ∼ N q(x|z)\sim N q(xz)N时,容易看出这是均方误差损失:
− log ⁡ q ( x ∣ z ) = 1 2 σ 2 ∣ ∣ x − μ k ∣ ∣ 2 -\log q(x|z) = \frac {1}{2\sigma_2}||x-\mu_k||^2 logq(xz)=2σ21xμk2

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值