变分推断公式推导

变分推断公式推导

背景介绍

机器学习中的概率模型可分为频率派和贝叶斯派。频率派最终是求一个优化问题,而贝叶斯派则是求一个积分问题

频率派

举几个例子:

线性回归

样本数据: { ( x i , y i ) } i = 1 N \{(x_i,y_i)\}_{i=1}^N {(xi,yi)}i=1N

  • 模型: f ( w ) = w T x f(w)=w^Tx f(w)=wTx

  • 策略:损失函数: L ( w ) = ∑ i = 1 N ∣ ∣ w T x i − y i ∣ ∣ 2 L(w)=\sum_{i=1}^N||w^Tx_i-y_i||^2 L(w)=i=1N∣∣wTxiyi2 w ^ = arg ⁡ min ⁡ w L ( w ) \hat{w}=\arg\min_wL(w) w^=argminwL(w) 这就是一个无约束优化问题。

  • 算法:解法

    • 解析解:线性回归问题形式比较简单,可直接由最小二乘法求出解析解: w ∗ = ( X T X ) − 1 X T Y w^*=(X^TX)^{-1}X^TY w=(XTX)1XTY
    • 数值解:对于其他较为复杂的算法无法解析。有一些求数值解的方法,如梯度下降等。

SVM

  • 模型: f ( w ) = s i g n ( w T x + b ) f(w)=sign(w^Tx+b) f(w)=sign(wTx+b)
  • 策略:损失函数: m i n 1 2 w T w      s . t .   y i ( w T x i + b ≥ 1 ) min\frac{1}{2}w^Tw\ \ \ \ s.t.\ y_i(w^Tx_i+b\ge 1) min21wTw    s.t. yi(wTxi+b1)。 是一个有约束的凸优化问题。
  • 算法:解法有 QP、拉格朗日对偶等。

EM算法
θ ( t + 1 ) = arg ⁡ max ⁡ θ ∫ Z log ⁡ P ( X , Z ∣ θ ) P ( Z ∣ X , θ ( t ) ) d Z \theta^{(t+1)}=\arg\max_{\theta}\int_Z\log P(X,Z|\theta)P(Z|X,\theta^{(t)})dZ θ(t+1)=argθmaxZlogP(X,Zθ)P(ZX,θ(t))dZ
EM算法也是通过迭代来求解最大对数似然的数值解。

贝叶斯派

为什么说贝叶斯派是求积分呢?我们先来看贝叶斯定理:
P ( θ ∣ X ) = P ( X ∣ θ ) P ( θ ) P ( X ) P(\theta|X)=\frac{P(X|\theta)P(\theta)}{P(X)} P(θX)=P(X)P(Xθ)P(θ)
贝叶斯推断,要求得后验 P ( θ ∣ X ) P(\theta|X) P(θX)

贝叶斯决策。决策可以理解为就是做预测。即 X X X 为已知的 N N N 个样本数据。决策就是求:
P ( x ~ ∣ X ) = ∫ θ P ( x ~ ∣ X ) d θ = ∫ θ P ( x ~ ∣ θ ) P ( θ ∣ X ) d θ P(\tilde{x}|X)=\int_\theta P(\tilde{x}|X)d\theta=\int_\theta P(\tilde{x}|\theta)P(\theta|X)d\theta P(x~X)=θP(x~X)dθ=θP(x~θ)P(θX)dθ
在通过贝叶斯推断求得后验 P ( θ ∣ X ) P(\theta|X) P(θX) 之后,就可以按照上式进行贝叶斯决策。而且上面这个式子也可以写成关于后验的期望的形式(期望就是求积分):
P ( x ~ ∣ X ) = E θ ∣ X [ P ( x ~ ∣ θ ) ] P(\tilde{x}|X)=\mathbb{E}_{\theta|X}[P(\tilde{x}|\theta)] P(x~X)=EθX[P(x~θ)]
贝叶斯派的关键就是求得后验 P ( θ ∣ X ) P(\theta|X) P(θX) ,即贝叶斯推断的过程。贝叶斯推断又可分为精确推断和近似推断:

  • 精确推断
  • 近似推断
    • 确定性近似:变分推断(本文的主题)
    • 随机近似:MCMC、MH、Gibbs

公式推导

符号含义: X X X 为观测数据, Z Z Z 为隐变量和参数。注意这里参数 θ \theta θ 也一同表示在 Z Z Z 中了。

再强调一下我们的目的:求后验 P ( Z ∣ X ) P(Z|X) P(ZX)

下面的前几步与 EM 算法导出的做法类似,详见 EM算法公式推导 ,区别只是把参数 θ \theta θ 合并到了 Z Z Z 中,步骤这里就不一一说明了。
log ⁡ P ( X ) = log ⁡ P ( X , Z ) − log ⁡ P ( Z ∣ X ) = log ⁡ P ( X , Z ) q ( Z ) − log ⁡ P ( Z ∣ X ) q ( Z ) = ∫ Z q ( Z ) log ⁡ P ( X , Z ) q ( Z ) d Z − ∫ Z q ( Z ) log ⁡ P ( Z ∣ X ) q ( Z ) d Z = E L B O + K L ( q ( Z ) ∣ ∣ P ( Z ∣ X ) ) = L ( q ) + K L ( q ( Z ) ∣ ∣ P ( Z ∣ X ) ) \begin{align} \log P(X)&=\log P(X,Z)-\log P(Z|X)\\ &=\log \frac{P(X,Z)}{q(Z)}-\log \frac{P(Z|X)}{q(Z)}\\ &=\int_Zq(Z)\log\frac{P(X,Z)}{q(Z)}dZ-\int_Zq(Z)\log \frac{P(Z|X)}{q(Z)}dZ\\ &=ELBO+KL(q(Z)||P(Z|X))\\ &=\mathcal{L}(q)+KL(q(Z)||P(Z|X)) \end{align} logP(X)=logP(X,Z)logP(ZX)=logq(Z)P(X,Z)logq(Z)P(ZX)=Zq(Z)logq(Z)P(X,Z)dZZq(Z)logq(Z)P(ZX)dZ=ELBO+KL(q(Z)∣∣P(ZX))=L(q)+KL(q(Z)∣∣P(ZX))
经过一系列变形,得到 E B L O + K L EBLO+KL EBLO+KL 的形式,这里我们将 E L B O ELBO ELBO 记为 L ( q ) \mathcal{L}(q) L(q) ,就是所谓的变分

我们是要求的是后验 P ( Z ∣ X ) P(Z|X) P(ZX) ,如果其与 q ( Z ) q(Z) q(Z) 的 KL 散度接近0,那么就能用 q ( Z ) q(Z) q(Z) 来对其进行近似。而等式左边 log ⁡ P ( X ) \log P(X) logP(X) Z Z Z 无关,因此 E L B O + K L ELBO+KL ELBO+KL q ( Z ) q(Z) q(Z) 变化时是个定值,因此,要让 KL 尽量小就转换为让 ELBO 尽量大,即有:
q ^ ( Z ) = arg ⁡ max ⁡ q ( Z ) L ( q )      →      q ( Z ) ≈ P ( Z ∣ X ) \hat{q}(Z)=\arg\max_{q(Z)}\mathcal{L}(q)\ \ \ \ \rightarrow\ \ \ \ q(Z)\approx P(Z|X) q^(Z)=argq(Z)maxL(q)        q(Z)P(ZX)
接下来,我们根据平均场理论,将 q ( Z ) q(Z) q(Z) 划分为 M M M相互独立的份:
q ( Z ) = ∏ i = 1 M q i ( Z i ) q(Z)=\prod_{i=1}^Mq_i(Z_i) q(Z)=i=1Mqi(Zi)
之后在求解的时候,我们会先固定 q 1 , q 2 , … , q j − 1 , … , q M q_1,q_2,\dots,q_{j-1},\dots,q_M q1,q2,,qj1,,qM ,然后求解单个分量 q j q_j qj ,最后将所有分量连乘起来,得到完整的 q ( Z ) q(Z) q(Z)

首先先将 q ( Z ) q(Z) q(Z) 代回到原式中:
L ( q ) = ∫ Z q ( Z ) log ⁡ P ( X , Z ) d Z − ∫ Z log ⁡ q ( Z ) d Z = ① − ② \mathcal{L}(q)=\int_Zq(Z)\log P(X,Z)dZ-\int_Z\log q(Z)dZ=①-②\\ L(q)=Zq(Z)logP(X,Z)dZZlogq(Z)dZ=
一项一项地来看:
① = ∫ Z q ( Z ) log ⁡ P ( X , Z ) d Z = ∫ Z ∏ i = 1 M q i ( Z i ) log ⁡ P ( X , Z ) d Z = ∫ Z j q j ( Z j ) ∫ Z i ( i ≠ j ) ∏ i ≠ j M q i ( Z i ) log ⁡ P ( X , Z ) d Z i ( i ≠ j ) d Z j = ∫ Z j q j ( Z j ) ∫ Z i ( i ≠ j ) log ⁡ P ( X , Z ) ∏ i ≠ j M q i ( Z i ) d Z i ( i ≠ j ) d Z j = ∫ Z j q j ( Z j ) ⋅ E ∏ i ≠ j M q i ( Z i ) [ log ⁡ P ( X , Z ) ] d Z j \begin{align} ①&=\int_Zq(Z)\log P(X,Z)dZ\\ &=\int_Z\prod_{i=1}^Mq_i(Z_i)\log P(X,Z)dZ\\ &=\int_{Z_j}q_j(Z_j)\int_{Z_i(i\ne j)}\prod_{i\ne j}^Mq_i(Z_i)\log P(X,Z)dZ_{i(i\ne j)}dZ_j\\ &=\int_{Z_j}q_j(Z_j)\int_{Z_i(i\ne j)}\log P(X,Z)\prod_{i\ne j}^Mq_i(Z_i)dZ_{i(i\ne j)}dZ_j\\ &=\int_{Z_j}q_j(Z_j)\cdot\mathbb{E}_{\prod_{i\ne j}^Mq_i(Z_i)}[\log P(X,Z)]dZ_j \end{align} =Zq(Z)logP(X,Z)dZ=Zi=1Mqi(Zi)logP(X,Z)dZ=Zjqj(Zj)Zi(i=j)i=jMqi(Zi)logP(X,Z)dZi(i=j)dZj=Zjqj(Zj)Zi(i=j)logP(X,Z)i=jMqi(Zi)dZi(i=j)dZj=Zjqj(Zj)Ei=jMqi(Zi)[logP(X,Z)]dZj

  • 先将 q ( Z ) q(Z) q(Z) 进行拆分为 M M M 份;
  • 然后将第 j j j 份拆出来;
  • 其他份的积分写成期望的形式(见到积分,就考虑能写成期望)

然后看后面一项:
② = ∫ Z q ( Z ) log ⁡ q ( Z ) d Z = ∫ Z ∏ i = 1 M q i ( Z i ) log ⁡ ∏ i = 1 M q i ( Z i ) d Z = ∫ Z ∏ i = 1 M q i ( Z i ) ∑ i = 1 M log ⁡ q i ( Z i ) d Z = ∫ Z ∏ i = 1 M q i ( Z i ) [ log ⁡ q 1 ( Z 1 ) + log ⁡ q 2 ( Z 2 ) + ⋯ + log ⁡ q M ( Z M ) ] d Z \begin{align} ②&=\int_Zq(Z)\log q(Z)dZ\\ &=\int_Z\prod_{i=1}^Mq_i(Z_i)\log\prod_{i=1}^M q_i(Z_i)dZ\\ &=\int_Z\prod_{i=1}^Mq_i(Z_i)\sum_{i=1}^M\log q_i(Z_i)dZ\\ &=\int_Z\prod_{i=1}^Mq_i(Z_i)[\log q_1(Z_1)+\log q_2(Z_2)+\dots+\log q_M(Z_M)]dZ\\ \end{align} =Zq(Z)logq(Z)dZ=Zi=1Mqi(Zi)logi=1Mqi(Zi)dZ=Zi=1Mqi(Zi)i=1Mlogqi(Zi)dZ=Zi=1Mqi(Zi)[logq1(Z1)+logq2(Z2)++logqM(ZM)]dZ

  • 写成 M M M 份;
  • log 里面乘变外面加;
  • 把连加号写开;
  • 然后我们看其中一项(比如第一项):

∫ Z ∏ i = 1 M q i ( Z i ) ⋅ log ⁡ q 1 ( Z 1 ) d Z = ∫ Z q 1 ( Z 1 ) q 2 ( Z 2 ) … q M ( Z M ) log ⁡ q 1 ( Z 1 ) d Z = ∫ Z 1 Z 2 … Z M q 1 ( Z 1 ) q 2 ( Z 2 ) … q M ( Z M ) log ⁡ q 1 ( Z 1 ) d Z 1 d Z 2 … d Z M = ∫ Z 1 q 1 ( Z 1 ) log ⁡ q 1 ( Z 1 ) d Z 1 ∏ i = 2 M ∫ Z i q i ( Z i ) d Z i = ∫ Z 1 q 1 ( Z 1 ) log ⁡ q 1 ( Z 1 ) d Z 1 \begin{align} \int_Z\prod_{i=1}^Mq_i(Z_i)\cdot\log q_1(Z_1)dZ&=\int_Zq_1(Z_1)q_2(Z_2)\dots q_M(Z_M)\log q_1(Z_1)dZ\\ &=\int_{Z_1Z_2\dots Z_M}q_1(Z_1)q_2(Z_2)\dots q_M(Z_M)\log q_1(Z_1)dZ_1dZ_2\dots dZ_M\\ &=\int_{Z_1}q_1(Z_1)\log q_1(Z_1)dZ_1\prod_{i=2}^M\int_{Z_i}q_i(Z_i)dZ_i\\ &=\int_{Z_1}q_1(Z_1)\log q_1(Z_1)dZ_1 \end{align} Zi=1Mqi(Zi)logq1(Z1)dZ=Zq1(Z1)q2(Z2)qM(ZM)logq1(Z1)dZ=Z1Z2ZMq1(Z1)q2(Z2)qM(ZM)logq1(Z1)dZ1dZ2dZM=Z1q1(Z1)logq1(Z1)dZ1i=2MZiqi(Zi)dZi=Z1q1(Z1)logq1(Z1)dZ1

  • q 1 ( Z 1 ) q_1(Z_1) q1(Z1) 相关的移到一起;
  • 剩下的积分全都是 1

② = ∑ i = 1 M ∫ Z i q i ( Z i ) log ⁡ q i ( Z i ) d Z i = ∫ Z j q j ( Z j ) log ⁡ q j ( Z j ) d Z j + C \begin{align} ②&=\sum_{i=1}^M\int_{Z_i}q_i(Z_i)\log q_i(Z_i)dZ_i\\ &=\int_{Z_j}q_j(Z_j)\log q_j(Z_j)dZ_j+C\\ \end{align} =i=1MZiqi(Zi)logqi(Zi)dZi=Zjqj(Zj)logqj(Zj)dZj+C

  • 有了 i = 1 i=1 i=1 时的表示,我们就把整个第二项写成连加的形式;
  • 我们只关心第 j j j 项,其余的视作常数 C C C

这样处理完两项,有:
① − ② = ∫ Z j q j ( Z j ) ⋅ E ∏ i ≠ j M q i ( Z i ) [ log ⁡ P ( X , Z ) ] d Z j − ∫ Z j q j ( Z j ) log ⁡ q j ( Z j ) d Z j + C = ∫ Z j q j ( Z j ) ⋅ log ⁡ P ^ ( X , Z j ) d Z j − ∫ Z j q j ( Z j ) log ⁡ q j ( Z j ) d Z j + C = ∫ Z j q j ( Z j ) ⋅ log ⁡ P ^ ( X , Z j ) q j ( Z j ) d Z j = − K L ( P ^ ( X , Z j ) ∣ ∣ q j ( Z j ) ) ≤ 0 \begin{align} ①-②&=\int_{Z_j}q_j(Z_j)\cdot\mathbb{E}_{\prod_{i\ne j}^Mq_i(Z_i)}[\log P(X,Z)]dZ_j-\int_{Z_j}q_j(Z_j)\log q_j(Z_j)dZ_j+C\\ &=\int_{Z_j}q_j(Z_j)\cdot\log \hat{P}(X,Z_j) dZ_j-\int_{Z_j}q_j(Z_j)\log q_j(Z_j)dZ_j+C\\ &=\int_{Z_j}q_j(Z_j)\cdot\log\frac{ \hat{P}(X,Z_j)}{q_j(Z_j)}dZ_j\\ &=-KL(\hat{P}(X,Z_j)||q_j(Z_j))\le 0 \end{align} =Zjqj(Zj)Ei=jMqi(Zi)[logP(X,Z)]dZjZjqj(Zj)logqj(Zj)dZj+C=Zjqj(Zj)logP^(X,Zj)dZjZjqj(Zj)logqj(Zj)dZj+C=Zjqj(Zj)logqj(Zj)P^(X,Zj)dZj=KL(P^(X,Zj)∣∣qj(Zj))0

  • 将 ① 中的期望写成一个函数的形式: P ^ ( X , Z j ) \hat{P}(X,Z_j) P^(X,Zj)
  • 最后就是一个负的 KL 散度,当 P ^ ( X , Z j ) = q j ( Z j ) \hat{P}(X,Z_j)=q_j(Z_j) P^(X,Zj)=qj(Zj) 时取到等号

Ref

  1. 机器学习白板推导
  • 0
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值