(简单易懂)Variational Inference 变分推理

Variational Inference 变分推理

1.AE与VAE

在这里插入图片描述

Figure 1: Auto-Encoder.

​ 如Fig.1,作为经典的网络结构之一,Auto-Encoder在深度学习的多个领域中有着出色的表现,依赖于重构误差的反向优化,模型可以学习到数据在低维空间的表示,过滤掉数据的冗余特征,得到细粒度特征。但是,AE学习的是确定性编码(潜在表示经常是确定性的),而不是数据的概率分布,不能直接提供关于数据分布的信息,可能会产生过拟合等问题。

在这里插入图片描述

Figure 2.

​ 如Fig.2,为了提升模型的泛化能力,使模型具有理解数据分布和生成新样本的能力,Variational Auto-Encoder横空出世。与AE结构不同的是,VAE的潜在表示 z z z 的生成方式由分布采样得到。通过编码器得到潜在分布的均值 μ \mu μ 与标准差 σ \sigma σ 。在此基础上,加入服从正态分布的噪声 ϵ ∼ N ( 0 , 1 ) \epsilon\sim\mathcal{N}(0,1) ϵN(0,1) ,使得 q ( z ∣ x ) = N ( μ , e x p ( σ ) 2 ) q(z|x)=\mathcal{N}(\mu,{exp(\sigma)}^2) q(zx)=N(μ,exp(σ)2) ,同时令 z = μ + e x p ( σ ) × ϵ z=\mu+exp(\sigma)\times\epsilon z=μ+exp(σ)×ϵ 。其中 e x p ( σ ) > 0 exp(\sigma)>0 exp(σ)>0 相当于噪声强度因子,且噪声的添加使得模型更具有抗扰动能力。

​ 在loss function的约束项上,也要最小化 ∑ ( e x p ( σ ) − ( 1 + σ ) + μ 2 ) \sum(exp(\sigma)-(1+\sigma)+\mu^2) (exp(σ)(1+σ)+μ2) ,令 σ \sigma σ 趋近于0, μ \mu μ 趋近于0,故 e x p ( σ ) exp(\sigma) exp(σ) 趋近于1。由此,使得 q ( z ∣ x ) q(z|x) q(zx) 趋近于标准正态分布。注:尽管不同数据训练出的变分自编码器(VAE)可能会使潜在分布接近于标准正态分布,但由于模型内部参数的不同和数据的特征差异,即使手动将标准正态分布生成的潜在变量 z ′ z' z​ 输入到不同数据训练出的解码器中,不同解码器所生成的数据之间也会有很大的差异。

2.变分推理的数学推导

​ 虽然我们通过神经网络得到了 q ( z ∣ x ) = N ( μ , e x p ( σ ) 2 ) q(z|x)=\mathcal{N}(\mu,{exp(\sigma)}^2) q(zx)=N(μ,exp(σ)2) ,但是 q ( z ∣ x ) q(z|x) q(zx) 就是由 x x x 得到的 z z z 的真实分布吗?定义由 x x x 得到的 z z z 的真实分布为 p ( z ∣ x ) p(z|x) p(zx) ,我们通过神经网络学习得到 q ( z ∣ x ) q(z|x) q(zx) 来近似 p ( z ∣ x ) p(z|x) p(zx) ,确保神经网络生成分布的准确性。我们假设先验分布 p ( z ) p(z) p(z) 与 后验分布 p ( z ∣ x ) p(z|x) p(zx) 为正态分布(因为计算KL等公式的便利性、标准正态分布被认为是一种无信息性先验、标准正态分布在潜在空间的均匀性等等),因此,为了拉近 q ( z ∣ x ) q(z|x) q(zx) p ( z ∣ x ) p(z|x) p(zx)​ 我们使用KL散度(Kullback-Leibler Divergence)度量两个概率分布之间的差异程度(离散型、连续型):

离散型
D K L ( P ∥ Q ) = ∑ i = 1 n P i l o g ( P i Q i ) (1) D_{KL}(P\parallel Q)=\sum_{i=1}^n P_ilog(\frac{P_i}{Q_i}) \tag{1} DKL(PQ)=i=1nPilog(QiPi)(1)
其中 P , Q P,Q P,Q 为离散型随机变量的概率分布律

连续型
D K L ( P ∥ Q ) = ∫ − ∞ + ∞ p ( x ) l o g p ( x ) q ( x ) d x (2) D_{KL}\left(P\parallel Q\right)=\int_{-\infty}^{+\infty}p(x)log\frac{p(x)}{q(x)}dx \tag{2} DKL(PQ)=+p(x)logq(x)p(x)dx(2)
其中 P , Q P,Q P,Q​ 为连续型随机变量的概率密度

​ 因此,我们用KL散度衡量 q ( z ∣ x ) q(z|x) q(zx) p ( z ∣ x ) p(z|x) p(zx) ,最小化下式:
K L ( q ( z ∣ x ) ∣ ∣ p ( z ∣ x ) ) = ∫ q ( z ∣ x ) log ⁡ q ( z ∣ x ) p ( z ∣ x ) d z (3) KL(q(z\mid x)||p(z\mid x))=\int q(z\mid x)\log\frac{q(z\mid x)}{p(z\mid x)}dz \tag{3} KL(q(zx)∣∣p(zx))=q(zx)logp(zx)q(zx)dz(3)
根据贝叶斯定理:
p ( z ∣ x ) = p ( x ∣ z ) p ( z ) p ( x ) (4) p(z\mid x)=\frac{p(x\mid z)p(z)}{p(x)} \tag{4} p(zx)=p(x)p(xz)p(z)(4)
原式等价于:
K L ( q ( z ∣ x ) ∣ ∣ p ( z ∣ x ) ) = ∫ q ( z ∣ x ) log ⁡ q ( z ∣ x ) p ( x ∣ z ) p ( z ) p ( x ) d z = ∫ q ( z ∣ x ) log ⁡ q ( z ∣ x ) d z + ∫ q ( z ∣ x ) log ⁡ p ( x ) d z − ∫ q ( z ∣ x ) log ⁡ [ p ( x ∣ z ) p ( z ) ] d z = ∫ q ( z ∣ x ) log ⁡ q ( z ∣ x ) d z + log ⁡ p ( x ) ∫ q ( z ∣ x ) d z − ∫ q ( z ∣ x ) log ⁡ [ p ( x ∣ z ) p ( z ) ] d z   ( 其中 : ∫ q ( z ∣ x ) d z = 1 ) = log ⁡ p ( x ) + ∫ q ( z ∣ x ) log ⁡ q ( z ∣ x ) d z − ∫ q ( z ∣ x ) log ⁡ [ p ( x ∣ z ) p ( z ) ] d z = log ⁡ p ( x ) + ∫ q ( z ∣ x ) log ⁡ q ( z ∣ x ) d z − ∫ q ( z ∣ x ) log ⁡ [ p ( x ∣ z ) p ( z ) ] d z = log ⁡ p ( x ) + ∫ q ( z ∣ x ) log ⁡ q ( z ∣ x ) d z − ∫ q ( z ∣ x ) log ⁡ p ( x ∣ z ) d z − ∫ q ( z ∣ x ) log ⁡ p ( z ) d z = log ⁡ p ( x ) + ∫ q ( z ∣ x ) log ⁡ q ( z ∣ x ) p ( z ) d z − ∫ q ( z ∣ x ) log ⁡ p ( x ∣ z ) d z = log ⁡ p ( x ) − E z ∼ q ( z ∣ x ) [ log ⁡ p ( x ∣ z ) ] + D K L ( q ( z ∣ x ) ∣ ∣ p ( z ) ) \begin{align*} KL(q(z\mid x)||p(z\mid x))&=\int q(z\mid x)\log\frac{q(z\mid x)}{\frac{p(x\mid z)p(z)}{p(x)}}dz \\ &=\int q(z\mid x)\log q(z\mid x)dz+\int q(z\mid x)\log p(x)dz-\int q(z\mid x)\log[p(x\mid z)p(z)]dz \\ &=\int q(z\mid x)\log q(z\mid x)dz+\log p(x)\int q(z\mid x)dz-\int q(z\mid x)\log[p(x\mid z)p(z)]dz \ (其中:\int q(z\mid x)dz=1) \\ &=\log p(x)+\int q(z\mid x)\log q(z\mid x)dz-\int q(z\mid x)\log[p(x\mid z)p(z)]dz \\ &=\log p(x)+\int q(z\mid x)\log q(z\mid x)dz-\int q(z\mid x)\log[p(x\mid z)p(z)]dz \\ &=\log p(x)+\int q(z\mid x)\log q(z\mid x)dz-\int q(z\mid x)\log p(x\mid z)dz-\int q(z\mid x)\log p(z)dz \\ &=\log p(x)+\int q(z\mid x)\log\frac{q(z\mid x)}{p(z)}dz-\int q(z\mid x) \log p(x\mid z)dz \\ &=\log p(x)-E_{z\sim q(z\mid x)}\left[\log p(x\mid z)\right]+D_{KL}\left(q(z\mid x)||p(z)\right) \tag{5} \end{align*} KL(q(zx)∣∣p(zx))=q(zx)logp(x)p(xz)p(z)q(zx)dz=q(zx)logq(zx)dz+q(zx)logp(x)dzq(zx)log[p(xz)p(z)]dz=q(zx)logq(zx)dz+logp(x)q(zx)dzq(zx)log[p(xz)p(z)]dz (其中:q(zx)dz=1)=logp(x)+q(zx)logq(zx)dzq(zx)log[p(xz)p(z)]dz=logp(x)+q(zx)logq(zx)dzq(zx)log[p(xz)p(z)]dz=logp(x)+q(zx)logq(zx)dzq(zx)logp(xz)dzq(zx)logp(z)dz=logp(x)+q(zx)logp(z)q(zx)dzq(zx)logp(xz)dz=logp(x)Ezq(zx)[logp(xz)]+DKL(q(zx)∣∣p(z))(5)
由于z是从分布中进行采样得到的,而采样过程是不可导的,而我们需要梯度的反传优化,为了将 Eq. (5) E z ∼ q ( z ∣ x ) [ log ⁡ p ( x ∣ z ) ] = ∫ q ( z ∣ x ) log ⁡ p ( x ∣ z ) d z E_{z\sim q(z\mid x)}\left[\log p(x\mid z)\right]=\int q(z\mid x) \log p(x\mid z)dz Ezq(zx)[logp(xz)]=q(zx)logp(xz)dz 中的 z z z 消掉,我们使用重参数化技巧(例:给定 Z ∼ N ( μ , σ 2 ) Z\sim\mathcal{N}(\mu,\sigma^2) ZN(μ,σ2) ϵ ∼ N ( 0 , I ) \epsilon\sim\mathcal{N}(0,\mathbf{I}) ϵN(0,I) 故将 Z Z Z 转化为 Z = μ + σ ϵ Z=\mu+\sigma\epsilon Z=μ+σϵ)将对 z z z 的采样等价于对其分布的均值,标准差的采样。

根据 Eq. (5),我们进行拆分:
D K L ( q ( z ∣ x ) ∣ ∣ p ( z ) ) = ∫ q ( z ∣ x ) log ⁡ q ( z ∣ x ) p ( z ) d z = ∫ q ( z ∣ x ) log ⁡ q ( z ∣ x ) d z − ∫ q ( z ∣ x ) log ⁡ p ( z ) d z = ∫ N ( z ; μ , σ 2 ) log ⁡ N ( z ; μ , σ 2 ) d z − ∫ N ( z ; μ , σ 2 ) log ⁡ N ( z ; 0 , I ) d z = − J 2 log ⁡ ( 2 π ) − 1 2 ∑ j = 1 J ( 1 + log ⁡ σ j 2 ) − ( − J 2 log ⁡ ( 2 π ) − 1 2 ∑ j = 1 J ( μ j 2 + σ j 2 ) ) = 1 2 ∑ j = 1 J ( ( μ j ) 2 + ( σ j ) 2 − 1 − log ⁡ ( ( σ j ) 2 ) ) \begin{align*} D_{KL}\left(q(z\mid x)||p(z)\right)&=\int q(z\mid x)\log\frac{q(z\mid x)}{p(z)}dz \\ &=\int q(z\mid x)\log q(z\mid x)dz-\int q(z\mid x)\log p(z)dz \\ &=\int\mathcal{N}\left(\mathbf{z};\boldsymbol{\mu},\boldsymbol{\sigma}^2\right)\log\mathcal{N}\left(\mathbf{z};\boldsymbol{\mu},\boldsymbol{\sigma}^2\right)d\mathbf{z}-\int\mathcal{N}\left(\mathbf{z};\boldsymbol{\mu},\boldsymbol{\sigma}^2\right)\log\mathcal{N}(\mathbf{z};\boldsymbol{0},\mathbf{I})d\mathbf{z} \\ &=-\frac J2\log(2\pi)-\frac12\sum_{j=1}^J\left(1+\log\sigma_j^2\right)-(-\frac J2\log(2\pi)-\frac12\sum_{j=1}^J\left(\mu_j^2+\sigma_j^2\right)) \\ &=\frac12\sum_{j=1}^J\left(\left(\mu_j\right)^2+\left(\sigma_j\right)^2-1-\log\left(\left(\sigma_j\right)^2\right)\right) \tag{6} \end{align*} DKL(q(zx)∣∣p(z))=q(zx)logp(z)q(zx)dz=q(zx)logq(zx)dzq(zx)logp(z)dz=N(z;μ,σ2)logN(z;μ,σ2)dzN(z;μ,σ2)logN(z;0,I)dz=2Jlog(2π)21j=1J(1+logσj2)(2Jlog(2π)21j=1J(μj2+σj2))=21j=1J((μj)2+(σj)21log((σj)2))(6)
其中 log ⁡ p ( x ) − E z ∼ q ( z ∣ x ) [ log ⁡ p ( x ∣ z ) ] \log p(x)-E_{z\sim q(z\mid x)}\left[\log p(x\mid z)\right] logp(x)Ezq(zx)[logp(xz)] 等价于MSE(或其他Loss,代表真实值与预测值的损失):
log ⁡ p ( x ) − E z ∼ q ( z ∣ x ) [ log ⁡ p ( x ∣ z ) ] → M S E = 1 n ∑ i − 1 n ( x i − y i ) 2 \begin{align*} \log p(x)-E_{z\sim q(z\mid x)}\left[\log p(x\mid z)\right]\to MSE=\frac1n\sum_{i-1}^{n}{(x_i-y_i)^2} \tag{7} \end{align*} logp(x)Ezq(zx)[logp(xz)]MSE=n1i1n(xiyi)2(7)
因此 L L L 如下:
L = 1 n ∑ i − 1 n ( x i − y i ) 2 + 1 2 ∑ j = 1 J ( ( μ j ) 2 + ( σ j ) 2 − 1 − log ⁡ ( ( σ j ) 2 ) ) (8) L=\frac1n\sum_{i-1}^{n}{(x_i-y_i)^2}+\frac12\sum_{j=1}^J\left(\left(\mu_j\right)^2+\left(\sigma_j\right)^2-1-\log\left(\left(\sigma_j\right)^2\right)\right) \tag{8} L=n1i1n(xiyi)2+21j=1J((μj)2+(σj)21log((σj)2))(8)
​ 综上所述,我们通过最小化 L L L 使神经网络学习得到的 q ( z ∣ x ) q(z|x) q(zx) 来近似真实分布 p ( z ∣ x ) p(z|x) p(zx) ,使自编码器具有泛化性和生成新样本的能力。

3.参考

VAE原文

VAE(变分自编码器) 详解

变分自编码器VAE ——公式推导(含实现代码)

VAE模型(Variational Autoencoders)

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值