变分推断学习

#! https://zhuanlan.zhihu.com/p/401456634

变分推断

1.变分推断的背景

   在机器学习中,有很多求后验概率的问题,求后验概率的过程被称为推断(Inference),推断分为精确推断和近似推断。精确推断一般主要是根据贝叶斯等概率公式推导出后验概率,但在一些生成模型中,如RBM, DBN, DBM很难应用精确推断,于是乎我们就有了近似推断,近似推断又分为确定性近似和随机性近似,确定性近似方法就是变分推断(Variance Inference, VI),随机性近似的方法有马尔可夫蒙特卡洛采样(Markov Chain Monte Carlo, MCMC)、Metropolis-Hastings采样(MH)、吉布斯采样(Gibbs)。具体关系如下图:
image

2.变分推断介绍

2.1.问题的提出

假设 X X X是观测数据, Z Z Z是隐变量+参数, ( X , Z ) (X, Z) (X,Z)是完全数据。
在最大似然估计里面有这么个事实:
log ⁡ p ( x ) = log ⁡ p ( x , z ) − log ⁡ p ( z ∣ x ) = log ⁡ p ( x , z ) q ( z ) − log ⁡ p ( z ∣ x ) q ( z ) \begin{aligned} \log p(x) &=\log p(x, z)-\log p(z \mid x) \\ &=\log \frac{p(x, z)}{q(z)}-\log \frac{p(z \mid x)}{q(z)} \end{aligned} logp(x)=logp(x,z)logp(zx)=logq(z)p(x,z)logq(z)p(zx)
两边都关于 z z z积分有,左边= ∫ z log ⁡ p ( x ) q ( z ) d z = log ⁡ p ( x ) \int_{z} \log p(x) q(z) d z=\log p(x) zlogp(x)q(z)dz=logp(x),而右边,
∫ z q ( z ) log ⁡ p ( x , z ) q ( z ) d z + ( − ∫ z q ( z ) log ⁡ p ( z ∣ x ) q ( z ) d z ) = L ( q ) + K L ( q ∥ p ) (1) \begin{aligned} & \int_{z} q(z) \log \frac{p(x, z)}{\left.q{(z}\right)} d z+\left(-\int_{z} q(z) \log \frac{p(z \mid x)}{\left.q(z\right)} d z\right) \\ =& \mathcal{L}(q)+KL(q \| p) \end{aligned}\tag{1} =zq(z)logq(z)p(x,z)dz+(zq(z)logq(z)p(zx)dz)L(q)+KL(qp)(1)
其中 L ( q ) \mathcal{L}(q) L(q)就是我们说的变分,也即ELBO,这样就把对数似然转化为了变分和 K L KL KL散度的和。

因为后验函数 p ( z ∣ x ) p(z|x) p(zx)求不出,所以我们的任务就是要找到一个分布 q ( z ) q(z) q(z)来近似这个后验,也就是使 K L ( q ∥ p ) KL(q \| p) KL(qp)最小,此时要求变分 L ( q ) \mathcal{L}(q) L(q)达到最大值,也即把问题转化为了求 q ( z ) = a r g m a x q L ( q ) q(z) = \underset{q}{argmax}\mathcal{L}(q) q(z)=qargmaxL(q)

2.2.问题的转化

在介绍变分问题前,我们还需要了解到一个概念,平均场假设,这个假设来源于统计物理中的mean field theory思想,将 q ( z ) q(z) q(z)划分为M个独立的分布,即 q ( z ) = ∏ i = 1 M q i ( z i ) q(z)=\prod_{i=1}^{M}q_{i}(z_{i}) q(z)=i=1Mqi(zi)
代入 L ( q ) \mathcal{L}(q) L(q)中有:
L ( q ) = ∫ z q ( z ) log ⁡ p ( x , z ) d z − ∫ z q ( z ) log ⁡ q ( z ) d z \mathcal{L}(q)=\int_{z} q(z) \log p(x, z) d z-\int_{z} q(z) \log q(z) d z L(q)=zq(z)logp(x,z)dzzq(z)logq(z)dz,其中第一部分:
∫ z q ( z ) log ⁡ p ( x , z ) d z = ∫ z ∏ i = 1 m q i ( z i ) log ⁡ p ( x , z ) d z = ∫ z j q j ( z j ) [ ∫ ∏ i ≠ j M q i ( z i ) log ⁡ p ( x , z ) d z 1 d z 2 ⋯ d z M ] d z j = ∫ z j q j ( z j ) ⋅ E ∏ i ≠ j M q i ( z i ) [ log ⁡ p ( x , z ) ] d z j = ∫ z j q j ( z j ) ⋅ l o g p ^ ( x , z ) ] d z j \begin{aligned} \int_{z} q(z) \log p(x, z) d z &=\int_{z} \prod_{i=1}^{m} q_{i}\left(z_{i}\right) \log p\left(x, z\right) d z \\ &=\int_{z j} q_{j}\left(z_{j}\right)\left[\int \prod_{i \neq j}^{M} q_{i}\left(z_{i}\right) \log p(x, z)dz_1dz_2\cdots dz_M \right]dz_j\\ &=\int_{z_{j}} q_{j}\left(z_{j}\right) \cdot E_{\prod_{i \neq j}^{M} q_{i}\left(z_{i}\right)}[\log p(x, z)] d z_{j}\\ &=\int_{z_{j}} q_{j}\left(z_{j}\right) \cdot log \hat{p}(x, z)] d z_{j}\end{aligned} zq(z)logp(x,z)dz=zi=1mqi(zi)logp(x,z)dz=zjqj(zj)i=jMqi(zi)logp(x,z)dz1dz2dzMdzj=zjqj(zj)Ei=jMqi(zi)[logp(x,z)]dzj=zjqj(zj)logp^(x,z)]dzj
∫ z q ( z ) log ⁡ q ( z ) d z = ∫ z ∏ i = 1 M q i ( z i ) log ⁡ ∏ i = 1 M q i ( z i ) d z = ∫ z ∏ i = 1 m q i ( z i ) ∑ i = 1 m log ⁡ q i ( z i ) d z \begin{aligned} \int_{z} q(z) \log q(z) d z &=\int_{z} \prod_{i=1}^{M} q_{i} \left(z_{i}\right) \log \prod_{i=1}^{M} q_{i}\left(z_i\right) d z \\ &=\int_{z} \prod_{i=1}^{m} q_{i}\left(z_{i}\right) \sum_{i=1}^{m} \log q_{i}\left(z_{i}\right) d z \end{aligned} zq(z)logq(z)dz=zi=1Mqi(zi)logi=1Mqi(zi)dz=zi=1mqi(zi)i=1mlogqi(zi)dz
因为
∫ z ∏ i = 1 M q i ( z i ) log ⁡ q 1 d z = ∫ z 1 q 1 ( z 1 ) log ⁡ q 1 d z 1 ∫ z 2 q 2 ( z 2 ) d z 2 ⋯ ∫ z m q m ( z m ) d z m = ∫ z 1 q 1 ( z 1 ) log ⁡ q 1 d z 1 \begin{aligned} \int_{z} \prod_{i=1}^{M}q_{i}\left(z_{i}\right) \log q_{1} d z &=\int_{z_{1}}q_{1}\left(z_{1}\right) \log q_{1} d z_{1} \int_{z_{2}} q_{2}\left(z_{2}\right) d z_{2} \cdots \int_{z_{m}} q_{m}\left(z_{m}\right) d z_{m} \\ &=\int_{z_{1}}q_{1}\left(z_{1}\right) \log q_{1} d z_{1} \end{aligned} zi=1Mqi(zi)logq1dz=z1q1(z1)logq1dz1z2q2(z2)dz2zmqm(zm)dzm=z1q1(z1)logq1dz1
所以继续化简有:
∫ z q ( z ) log ⁡ q ( z ) d z = ∑ i = 1 m ∫ z i q i ( z i ) log ⁡ q i ( z i ) d z i = ∫ z j q j ( z j ) log ⁡ q j ( z j ) d z j + C \begin{aligned} \int_{z} q(z) \log q(z) d z &=\sum_{i=1}^{m} \int_{z_{i}} q_{i}\left(z_{i}\right) \log q_{i}\left(z_{i}\right) d z_{i} \\ &=\int_{z_{j}} q_{j}\left(z_{j}\right) \log q_{j}\left(z_{j}\right) d z_{j}+C \end{aligned} zq(z)logq(z)dz=i=1mziqi(zi)logqi(zi)dzi=zjqj(zj)logqj(zj)dzj+C

此时只把 q j ( z j ) q_j(z_j) qj(zj)看作变量,其余看作常量。所以
L ( q ) = ∫ z j q j ( z j ) log ⁡ p ^ ( x , z ) q j ( z j ) d z j − C \mathcal{L}(q)=\int_{z_{j}} q_{j}\left(z_{j}\right) \log \frac{\hat{p}(x, z)}{q_{j}(z_{j})} d z_{j}-C L(q)=zjqj(zj)logqj(zj)p^(x,z)dzjC因为 C C C是常量,z在求极大的时候可以省略。故我们得到了:
L ( q ) = ∫ z j q j ( z j ) log ⁡ p ^ ( x , z j ) q j ( z j ) d z j = − K L ( q j ∥ p ^ ( x , z j ) ) ≤ 0 \begin{aligned}\mathcal{L}(q)&=\int_{z_{j}} q_{j}\left(z_{j}\right) \log \frac{\hat{p}(x, z_j)}{q_{j}(z_{j})} d z_{j}\\ &=-KL(q_j\|\hat{p}(x, z_j)) \le 0\end{aligned} L(q)=zjqj(zj)logqj(zj)p^(x,zj)dzj=KL(qjp^(x,zj))0

注意,以上推导都是建立在平均场假设上。
  关于推导过程中的 K L KL KL散度有一些细节问题,要知道 K L KL KL散度是不对称的, K L ( q ∥ p ) 和 K L ( p ∥ q ) KL(q\|p)和KL(p\|q) KL(qp)KL(pq)有着不同的性质, K L ( q ∥ p ) KL(q\|p) KL(qp)鼓励分布 q q q在真实分布 p p p达到高概率达到高概率, K L ( p ∥ q ) KL(p\|q) KL(pq)则鼓励分布 q q q在真实分布 p p p概率较低的地方概率较小,他们各自有其优缺点,应用则取决于两者哪种效果更好。出于计算的角度,我们选择用 K L ( q ∥ p ) KL(q\|p) KL(qp),因为其涉及求在分布 q q q下的数学期望,比起求在真实分布 p p p下的数学期望较为简单。
  实际上,VI和EM的方法有些类似,关于EM算法,下次再写。

2.3.问题的解决

  至此,我们推导出了变分的形式,变分学习的核心思想就是在一个关于 q q q的有约束的分布族上最大化 L \mathcal{L} L。要使 q = p q=p q=p,则有 L = 0 \mathcal{L}=0 L=0,根据上面的推导过程,此时有公式
log ⁡ q j ( z j ) = E ∏ i ≠ j [ log ⁡ p ( x , z ) ] + C \log q_j(z_j)=E_{\prod_{i\ne j}}[\log p(x, z)]+C logqj(zj)=Ei=j[logp(x,z)]+C
展开就有
log ⁡ q j ( z j ) = ∫ z 1 ∫ z 2 ⋯ ∫ z j − 1 ∫ z j + 1 ⋯ ∫ z M q 1 q 2 ⋯ q j − 1 q j + 1 ⋯ q M log ⁡ p ( x , z ) d z 1 d z 2 ⋯ d z j − 1 d z j + 1 ⋯ d M \log q_j(z_j)=\int_{z_1}\int_{z_2}\cdots \int_{z_{j-1}}\int_{z_{j+1}}\cdots\int_{z_{M}}q_1q_2\cdots q_{j-1}q_{j+1}\cdots q_{M}\log p(x, z)dz_1dz_2\cdots d_{z_j-1}d_{z_j+1}\cdots d_{M} logqj(zj)=z1z2zj1zj+1zMq1q2qj1qj+1qMlogp(x,z)dz1dz2dzj1dzj+1dM
然后我们用常规的迭代算法求解上式,比如坐标上升发就可以得到 log ⁡ q j ( z j ) \log q_j(z_j) logqj(zj),然后求出所有的 q j q_j qj就可以得到后验概率了,目的达到!

2.4.变分的缺点

通过上面的推导我们可以发现所有推导都是建立在平均场假设上的,而平均场假设本身就较难满足,所以变分的一个主要缺点就是假设很难满足,不实用。

2.5.变分的变种

我们上面说可以用坐标上升法来计算,那我们很自然的就可以想到可不可以用梯度上升法(SGA),答案是可以的。
我们要求解的问题是
q = a r g m a x q L ( q ) q = \underset{q}{argmax}\mathcal{L}(q) q=qargmaxL(q)
q ( z ) q(z) q(z)写成 q Φ ( z ) q_{\Phi}(z) qΦ(z),这里应用了以下重参数化的技巧,因为对概率密度函数求梯度是不容易的,所以我们抽象出了一个连续变量 Φ \Phi Φ,从而可以对其求导,这里的 Φ \Phi Φ仍然有
Φ = a r g m a x Φ L ( Φ ) \Phi=\underset{\Phi}{argmax}\mathcal{L}(\Phi) Φ=ΦargmaxL(Φ)
因为
L ( q ) = ∫ z q Φ ( z ) log ⁡ p ( x , z ) q Φ ( z ) d z = E q Φ [ log ⁡ p θ ( x , z ) − log ⁡ q Φ ] \mathcal{L}(q)=\int_{z}q_{\Phi}(z)\log \frac{p(x, z)}{q_{\Phi}(z)}dz=E_{q_{\Phi}}[\log p_{\theta}(x, z)-\log q_{\Phi}] L(q)=zqΦ(z)logqΦ(z)p(x,z)dz=EqΦ[logpθ(x,z)logqΦ]
所以
▽ Φ L ( Φ ) = ▽ Φ E q Φ [ log ⁡ p θ ( x , z ) − log ⁡ q Φ ] = ▽ Φ ∫ z q Φ ( z ) [ log ⁡ p θ ( x , z ) − log ⁡ q Φ ] d z = ∫ z ▽ Φ q Φ ( z ) [ log ⁡ p θ ( x , z ) − log ⁡ q Φ ] d z + ∫ z q Φ ▽ Φ [ log ⁡ p θ ( x , z ) − log ⁡ q Φ ] d z = ① + ② \begin{aligned} \bigtriangledown_{\Phi}\mathcal{L}(\Phi) &=\bigtriangledown_{\Phi}E_{q_{\Phi}}[\log p_{\theta}(x, z)-\log q_{\Phi}] \\ &=\bigtriangledown_{\Phi}\int_zq_{\Phi}(z)[\log p_{\theta}(x, z) - \log q_{\Phi}]dz \\ &=\int_z\bigtriangledown_{\Phi}q_{\Phi}(z)[\log p_{\theta}(x, z) - \log q_{\Phi}]dz + \int_zq_{\Phi}\bigtriangledown_{\Phi}[\log p_{\theta}(x, z) - \log q_{\Phi}]dz \\ &=①+②\end{aligned} ΦL(Φ)=ΦEqΦ[logpθ(x,z)logqΦ]=ΦzqΦ(z)[logpθ(x,z)logqΦ]dz=zΦqΦ(z)[logpθ(x,z)logqΦ]dz+zqΦΦ[logpθ(x,z)logqΦ]dz=+
分析②
∫ z q Φ ▽ Φ [ log ⁡ p θ ( x , z ) − log ⁡ q Φ ] d z = − ∫ z q Φ ∗ 1 q Φ ∗ ▽ Φ q Φ d z = − ∫ z ▽ Φ q Φ d z = − ▽ Φ ∫ z q Φ d z = 0 \int_zq_{\Phi}\bigtriangledown_{\Phi}[\log p_{\theta}(x, z) - \log q_{\Phi}]dz=-\int_zq_{\Phi}*\frac{1}{q_{\Phi}}*\bigtriangledown_{\Phi}q_{\Phi}dz=-\int_z\bigtriangledown_{\Phi}q_{\Phi}dz=-\bigtriangledown_{\Phi}\int_zq_{\Phi}dz=0 zqΦΦ[logpθ(x,z)logqΦ]dz=zqΦqΦ1ΦqΦdz=zΦqΦdz=ΦzqΦdz=0
所以
▽ Φ L ( q ) = ∫ z ▽ Φ q Φ ( z ) [ log ⁡ p θ ( x , z ) − log ⁡ q Φ ] d z \bigtriangledown_{\Phi}\mathcal{L}(q)=\int_z\bigtriangledown_{\Phi}q_{\Phi}(z)[\log p_{\theta}(x, z) - \log q_{\Phi}]dz ΦL(q)=zΦqΦ(z)[logpθ(x,z)logqΦ]dz
用一个小技巧, ▽ Φ q Φ ( z ) = q Φ ▽ Φ log ⁡ q Φ \bigtriangledown_{\Phi}q_{\Phi}(z)=q_{\Phi}\bigtriangledown_{\Phi}\log q_{\Phi} ΦqΦ(z)=qΦΦlogqΦ,有
= ∫ z q Φ ▽ Φ log ⁡ q Φ [ log ⁡ p θ ( x , z ) − log ⁡ q Φ ] d z = E q Φ [ ▽ Φ log ⁡ q Φ [ log ⁡ p θ ( x , z ) − log ⁡ q Φ ] \begin{aligned} &=\int_zq_{\Phi}\bigtriangledown_{\Phi}\log q_{\Phi}[\log p_{\theta}(x, z) - \log q_{\Phi}]dz\\ &=E_{q_{\Phi}}[\bigtriangledown_{\Phi}\log q_{\Phi}[\log p_{\theta}(x, z) - \log q_{\Phi}] \end{aligned} =zqΦΦlogqΦ[logpθ(x,z)logqΦ]dz=EqΦ[ΦlogqΦ[logpθ(x,z)logqΦ]
此时就可以应用我们熟悉的采样方法来求解这个期望了,然后再进行梯度上升即可。
这里会有一个问题,就是期望中 ▽ Φ log ⁡ q Φ \bigtriangledown_{\Phi}\log q_{\Phi} ΦlogqΦ对样本点比较敏感,当 q Φ q_{\Phi} qΦ较小时,其梯度会趋于无穷,这就造成了采样结果方差较高,这就意味着我们需要大量的样本去拟合这个期望,可以认为不太现实,所以我们继续介绍求解这个梯度的另外一个方法。
重参数化技巧:
z ∼ g Φ ( ϵ , x ) , ϵ ∼ P ( ϵ ) z\sim g_{\Phi}(\epsilon, x), \epsilon \sim P(\epsilon) zgΦ(ϵ,x),ϵP(ϵ),在题解情况下有 z ∼ q ( z ∣ x ) ⇒ ϵ ∼ P ( ϵ ) z\sim q(z|x) \Rightarrow \epsilon \sim P(\epsilon) zq(zx)ϵP(ϵ),并且 q ( z ∣ x ) d z = P ( ϵ ) d ϵ q(z|x)dz=P(\epsilon)d\epsilon q(zx)dz=P(ϵ)dϵ,这里是因为 ∫ z q ( z ∣ x ) d z = 1 , ∫ ϵ P ( ϵ ) d ϵ = 1 \int_zq(z|x)dz=1,\int_{\epsilon}P(\epsilon)d\epsilon=1 zq(zx)dz=1ϵP(ϵ)dϵ=1并且 ϵ \epsilon ϵ z z z又有着对应关系,所以上述等式成立。
▽ Φ L ( Φ ) = ▽ Φ E q Φ [ log ⁡ p θ ( x , z ) − log ⁡ q Φ ] = ▽ Φ ∫ ( log ⁡ p θ ( x , z ) − log ⁡ q Φ ) q Φ d z = ▽ Φ ∫ ϵ ( log ⁡ p θ ( x , z ) − log ⁡ q Φ ) P ( ϵ ) d ϵ = ▽ Φ E P ( ϵ ) [ log ⁡ p θ ( x , z ) − log ⁡ q Φ ] = E P ( ϵ ) [ ▽ Φ ( log ⁡ p θ ( x , z ) − log ⁡ q Φ ) ] = E P ( ϵ ) [ ▽ z ( log ⁡ p θ ( x , z ) − log ⁡ q Φ ) ⋅ ▽ Φ z ] = E P ( ϵ ) [ ▽ z ( log ⁡ p θ ( x , z ) − log ⁡ q Φ ) ] ▽ Φ g Φ ( ϵ , x ) \begin{aligned} \bigtriangledown_{\Phi}\mathcal{L(\Phi)}&=\bigtriangledown_{\Phi}E_{q_{\Phi}}[\log p_{\theta}(x, z)-\log q_{\Phi}] \\ &=\bigtriangledown_{\Phi}\int(\log p_{\theta}(x, z)-\log q_{\Phi})q_{\Phi}dz \\ &=\bigtriangledown_{\Phi}\int_{\epsilon}(\log p_{\theta}(x, z)-\log q_{\Phi})P(\epsilon)d\epsilon \\ &=\bigtriangledown_{\Phi}E_{P(\epsilon)}[\log p_{\theta}(x, z)-\log q_{\Phi}] \\ &=E_{P(\epsilon)}[\bigtriangledown_{\Phi}(\log p_{\theta}(x, z)-\log q_{\Phi})] \\ &=E_{P(\epsilon)}[\bigtriangledown_{z}(\log p_{\theta}(x, z)-\log q_{\Phi})\cdot\bigtriangledown_{\Phi}z] \\ &=E_{P(\epsilon)}[\bigtriangledown_{z}(\log p_{\theta}(x, z)-\log q_{\Phi})]\bigtriangledown_{\Phi}g_{\Phi}(\epsilon, x) \end{aligned} ΦL(Φ)=ΦEqΦ[logpθ(x,z)logqΦ]=Φ(logpθ(x,z)logqΦ)qΦdz=Φϵ(logpθ(x,z)logqΦ)P(ϵ)dϵ=ΦEP(ϵ)[logpθ(x,z)logqΦ]=EP(ϵ)[Φ(logpθ(x,z)logqΦ)]=EP(ϵ)[z(logpθ(x,z)logqΦ)Φz]=EP(ϵ)[z(logpθ(x,z)logqΦ)]ΦgΦ(ϵ,x)
其中, P ( ϵ ) P(\epsilon) P(ϵ)是我们自己取的分布, 如均匀分布、正态分布,方便采样, ▽ z ( log ⁡ p θ ( x , z ) − log ⁡ q Φ ) \bigtriangledown_{z}(\log p_{\theta}(x, z)-\log q_{\Phi}) z(logpθ(x,z)logqΦ) ▽ Φ g Φ ( ϵ , x ) \bigtriangledown_{\Phi}g_{\Phi}(\epsilon, x) ΦgΦ(ϵ,x)都是已知的假设条件,此时再根据MCMC采样就可以得到梯度值了。然后根据梯度上升公式更新即可
Φ ( t + 1 ) = Φ ( t ) + λ ▽ Φ L ( Φ ) \Phi^{(t+1)}=\Phi^{(t)}+\lambda \bigtriangledown_{\Phi}\mathcal{L(\Phi)} Φ(t+1)=Φ(t)+λΦL(Φ)
这就是随机梯度变分法(Stochastic Gradient Variational Inference, SGVI)。

3.总结

  我们说明了变分的背景,以及变分推断的作用,是为了求一些在比较难求的后验,通常在无向图里使用。第二部分我们从最原始的条件推导变分,利用了平均场假设,说明了可以用坐标上升法和梯度上升法来求取后验。变分的主要作用就是求后验概率。
水平有限,如有错误,敬请指正。

  • 4
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值