VI变分推断

本文介绍了变分推断中的EM算法和平均场近似,着重讲解了SGVI中如何通过梯度上升优化变分分布。讨论了平均场假设的局限性,并提出利用蒙特卡洛采样和Reparameterization技术降低梯度估计方差。
摘要由CSDN通过智能技术生成

变分推断

我们已经知道概率模型可以分为,频率派的优化问题和贝叶斯派的积分问题。从贝叶斯角度来看推断,对于 x ^ \hat{x} x^ 这样的新样本,需要得到:
p ( x ^ ∣ X ) = ∫ θ p ( x ^ , θ ∣ X ) d θ = ∫ θ p ( θ ∣ X ) p ( x ^ ∣ θ , X ) d θ p(\hat{x}|X)=\int_\theta p(\hat{x},\theta|X)d\theta=\int_\theta p(\theta|X)p(\hat{x}|\theta,X)d\theta p(x^X)=θp(x^,θX)dθ=θp(θX)p(x^θ,X)dθ
如果新样本和数据集独立,那么推断就是概率分布依参数后验分布的期望。

我们看到,推断问题的中心是参数后验分布的求解,推断分为:

  1. 精确推断
  2. 近似推断-参数空间无法精确求解
    1. 确定性近似-如变分推断
    2. 随机近似-如 MCMC,MH,Gibbs

基于平均场假设的变分推断

我们记 Z Z Z 为隐变量和参数的集合, Z i Z_i Zi 为第 i i i 维的参数,于是,回顾一下 EM 中的推导:
log ⁡ p ( X ) = log ⁡ p ( X , Z ) − log ⁡ p ( Z ∣ X ) = log ⁡ p ( X , Z ) q ( Z ) − log ⁡ p ( Z ∣ X ) q ( Z ) \log p(X)=\log p(X,Z)-\log p(Z|X)=\log\frac{p(X,Z)}{q(Z)}-\log\frac{p(Z|X)}{q(Z)} logp(X)=logp(X,Z)logp(ZX)=logq(Z)p(X,Z)logq(Z)p(ZX)
左右两边分别积分:
L e f t : ∫ Z q ( Z ) log ⁡ p ( X ) d Z = log ⁡ p ( X ) R i g h t : ∫ Z [ log ⁡ p ( X , Z ) q ( Z ) − log ⁡ p ( Z ∣ X ) q ( Z ) ] q ( Z ) d Z = E L B O + K L ( q , p ) Left:\int_Zq(Z)\log p(X)dZ=\log p(X)\\ Right:\int_Z[\log \frac{p(X,Z)}{q(Z)}-\log \frac{p(Z|X)}{q(Z)}]q(Z)dZ=ELBO+KL(q,p) Left:Zq(Z)logp(X)dZ=logp(X)Right:Z[logq(Z)p(X,Z)logq(Z)p(ZX)]q(Z)dZ=ELBO+KL(q,p)
第二个式子可以写为变分和 KL 散度的和:
L ( q ) + K L ( q , p ) L(q)+KL(q,p) L(q)+KL(q,p)
由于这个式子是常数,于是寻找 q ≃ p q\simeq p qp 就相当于对 L ( q ) L(q) L(q) 最大值。
q ^ ( Z ) = a r g m a x q ( Z ) L ( q ) \hat{q}(Z)=\mathop{argmax}_{q(Z)}L(q) q^(Z)=argmaxq(Z)L(q)
假设 q ( Z ) q(Z) q(Z) 可以划分为 M M M 个组(平均场近似):
q ( Z ) = ∏ i = 1 M q i ( Z i ) q(Z)=\prod\limits_{i=1}^Mq_i(Z_i) q(Z)=i=1Mqi(Zi)
因此,在 L ( q ) = ∫ Z q ( Z ) log ⁡ p ( X , Z ) d Z − ∫ Z q ( Z ) log ⁡ q ( Z ) L(q)=\int_Zq(Z)\log p(X,Z)dZ-\int_Zq(Z)\log{q(Z)} L(q)=Zq(Z)logp(X,Z)dZZq(Z)logq(Z) 中,看 p ( Z j ) p(Z_j) p(Zj) ,第一项:
KaTeX parse error: No such environment: align at position 8: \begin{̲a̲l̲i̲g̲n̲}̲\int_Zq(Z)\log …

第二项:
∫ Z q ( Z ) log ⁡ q ( Z ) d Z = ∫ Z ∏ i = 1 M q i ( Z i ) ∑ i = 1 M log ⁡ q i ( Z i ) d Z \int_Zq(Z)\log q(Z)dZ=\int_Z\prod\limits_{i=1}^Mq_i(Z_i)\sum\limits_{i=1}^M\log q_i(Z_i)dZ Zq(Z)logq(Z)dZ=Zi=1Mqi(Zi)i=1Mlogqi(Zi)dZ
展开求和项第一项为:
∫ Z ∏ i = 1 M q i ( Z i ) log ⁡ q 1 ( Z 1 ) d Z = ∫ Z 1 q 1 ( Z 1 ) log ⁡ q 1 ( Z 1 ) d Z 1 \int_Z\prod\limits_{i=1}^Mq_i(Z_i)\log q_1(Z_1)dZ=\int_{Z_1}q_1(Z_1)\log q_1(Z_1)dZ_1 Zi=1Mqi(Zi)logq1(Z1)dZ=Z1q1(Z1)logq1(Z1)dZ1
所以:
∫ Z q ( Z ) log ⁡ q ( Z ) d Z = ∑ i = 1 M ∫ Z i q i ( Z i ) log ⁡ q i ( Z i ) d Z i = ∫ Z j q j ( Z j ) log ⁡ q j ( Z j ) d Z j + C o n s t \int_Zq(Z)\log q(Z)dZ=\sum\limits_{i=1}^M\int_{Z_i}q_i(Z_i)\log q_i(Z_i)dZ_i=\int_{Z_j}q_j(Z_j)\log q_j(Z_j)dZ_j+Const Zq(Z)logq(Z)dZ=i=1MZiqi(Zi)logqi(Zi)dZi=Zjqj(Zj)logqj(Zj)dZj+Const
两项相减,令 E ∏ i ≠ j q i ( Z i ) [ log ⁡ p ( X , Z ) ] = log ⁡ p ^ ( X , Z j ) \mathbb{E}_{\prod\limits_{i\ne j}q_i(Z_i)}[\log p(X,Z)]=\log \hat{p}(X,Z_j) Ei=jqi(Zi)[logp(X,Z)]=logp^(X,Zj) 可以得到:
− ∫ Z j q j ( Z j ) log ⁡ q j ( Z j ) p ^ ( X , Z j ) d Z j ≤ 0 -\int_{Z_j}q_j(Z_j)\log\frac{q_j(Z_j)}{\hat{p}(X,Z_j)}dZ_j\le 0 Zjqj(Zj)logp^(X,Zj)qj(Zj)dZj0
于是最大的 q j ( Z j ) = p ^ ( X , Z j ) q_j(Z_j)=\hat{p}(X,Z_j) qj(Zj)=p^(X,Zj) 才能得到最大值。我们看到,对每一个 q j q_j qj,都是固定其余的 q i q_i qi,求这个值,于是可以使用坐标上升的方法进行迭代求解,上面的推导针对单个样本,但是对数据集也是适用的。

基于平均场假设的变分推断存在一些问题:

  1. 假设太强, Z Z Z 非常复杂的情况下,假设不适用
  2. 期望中的积分,可能无法计算

SGVI

Z Z Z X X X 的过程叫做生成过程或译码,反过来的额过程叫推断过程或编码过程,基于平均场的变分推断可以导出坐标上升的算法,但是这个假设在一些情况下假设太强,同时积分也不一定能算。我们知道,优化方法除了坐标上升,还有梯度上升的方式,我们希望通过梯度上升来得到变分推断的另一种算法。

我们的目标函数:
q ^ ( Z ) = a r g m a x q ( Z ) L ( q ) \hat{q}(Z)=\mathop{argmax}_{q(Z)}L(q) q^(Z)=argmaxq(Z)L(q)
假定 q ( Z ) = q ϕ ( Z ) q(Z)=q_\phi(Z) q(Z)=qϕ(Z),是和 ϕ \phi ϕ 这个参数相连的概率分布。于是 a r g m a x q ( Z ) L ( q ) = a r g m a x ϕ L ( ϕ ) \mathop{argmax}_{q(Z)}L(q)=\mathop{argmax}_{\phi}L(\phi) argmaxq(Z)L(q)=argmaxϕL(ϕ),其中 L ( ϕ ) = E q ϕ [ log ⁡ p θ ( x i , z ) − log ⁡ q ϕ ( z ) ] L(\phi)=\mathbb{E}_{q_\phi}[\log p_\theta(x^i,z)-\log q_\phi(z)] L(ϕ)=Eqϕ[logpθ(xi,z)logqϕ(z)],这里 x i x^i xi 表示第 i i i 个样本。
KaTeX parse error: No such environment: align at position 8: \begin{̲a̲l̲i̲g̲n̲}̲\nabla_\phi L(\…
这个期望可以通过蒙特卡洛采样来近似,从而得到梯度,然后利用梯度上升的方法来得到参数:
z l ∼ q ϕ ( z ) E q ϕ [ ( ∇ ϕ log ⁡ q ϕ ) ( log ⁡ p θ ( x i , z ) − log ⁡ q ϕ ( z ) ) ] ∼ 1 L ∑ l = 1 L ( ∇ ϕ log ⁡ q ϕ ) ( log ⁡ p θ ( x i , z ) − log ⁡ q ϕ ( z ) ) z^l\sim q_\phi(z)\\ \mathbb{E}_{q_\phi}[(\nabla_\phi\log q_\phi)(\log p_\theta(x^i,z)-\log q_\phi(z))]\sim \frac{1}{L}\sum\limits_{l=1}^L(\nabla_\phi\log q_\phi)(\log p_\theta(x^i,z)-\log q_\phi(z)) zlqϕ(z)Eqϕ[(ϕlogqϕ)(logpθ(xi,z)logqϕ(z))]L1l=1L(ϕlogqϕ)(logpθ(xi,z)logqϕ(z))
但是由于求和符号中存在一个对数项,于是直接采样的方差很大,需要采样的样本非常多。为了解决方差太大的问题,我们采用 Reparameterization 的技巧。

考虑:
∇ ϕ L ( ϕ ) = ∇ ϕ E q ϕ [ log ⁡ p θ ( x i , z ) − log ⁡ q ϕ ( z ) ] \nabla_\phi L(\phi)=\nabla_\phi\mathbb{E}_{q_\phi}[\log p_\theta(x^i,z)-\log q_\phi(z)] ϕL(ϕ)=ϕEqϕ[logpθ(xi,z)logqϕ(z)]
我们取: z = g ϕ ( ε , x i ) , ε ∼ p ( ε ) z=g_\phi(\varepsilon,x^i),\varepsilon\sim p(\varepsilon) z=gϕ(ε,xi),εp(ε),于是对后验: z ∼ q ϕ ( z ∣ x i ) z\sim q_\phi(z|x^i) zqϕ(zxi),有 ∣ q ϕ ( z ∣ x i ) d z ∣ = ∣ p ( ε ) d ε ∣ |q_\phi(z|x^i)dz|=|p(\varepsilon)d\varepsilon| qϕ(zxi)dz=p(ε)dε。代入上面的梯度中:
KaTeX parse error: No such environment: align at position 8: \begin{̲a̲l̲i̲g̲n̲}̲ \nabla_\phi L(…
对这个式子进行蒙特卡洛采样,然后计算期望,得到梯度。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值