变分推断公式推导
背景介绍
机器学习中的概率模型可分为频率派和贝叶斯派。频率派最终是求一个优化问题,而贝叶斯派则是求一个积分问题。
频率派
举几个例子:
线性回归
样本数据: { ( x i , y i ) } i = 1 N \{(x_i,y_i)\}_{i=1}^N {(xi,yi)}i=1N
-
模型: f ( w ) = w T x f(w)=w^Tx f(w)=wTx
-
策略:损失函数: L ( w ) = ∑ i = 1 N ∣ ∣ w T x i − y i ∣ ∣ 2 L(w)=\sum_{i=1}^N||w^Tx_i-y_i||^2 L(w)=∑i=1N∣∣wTxi−yi∣∣2, w ^ = arg min w L ( w ) \hat{w}=\arg\min_wL(w) w^=argminwL(w) 这就是一个无约束优化问题。
-
算法:解法
- 解析解:线性回归问题形式比较简单,可直接由最小二乘法求出解析解: w ∗ = ( X T X ) − 1 X T Y w^*=(X^TX)^{-1}X^TY w∗=(XTX)−1XTY
- 数值解:对于其他较为复杂的算法无法解析。有一些求数值解的方法,如梯度下降等。
SVM
- 模型: f ( w ) = s i g n ( w T x + b ) f(w)=sign(w^Tx+b) f(w)=sign(wTx+b)
- 策略:损失函数: m i n 1 2 w T w s . t . y i ( w T x i + b ≥ 1 ) min\frac{1}{2}w^Tw\ \ \ \ s.t.\ y_i(w^Tx_i+b\ge 1) min21wTw s.t. yi(wTxi+b≥1)。 是一个有约束的凸优化问题。
- 算法:解法有 QP、拉格朗日对偶等。
EM算法
θ
(
t
+
1
)
=
arg
max
θ
∫
Z
log
P
(
X
,
Z
∣
θ
)
P
(
Z
∣
X
,
θ
(
t
)
)
d
Z
\theta^{(t+1)}=\arg\max_{\theta}\int_Z\log P(X,Z|\theta)P(Z|X,\theta^{(t)})dZ
θ(t+1)=argθmax∫ZlogP(X,Z∣θ)P(Z∣X,θ(t))dZ
EM算法也是通过迭代来求解最大对数似然的数值解。
贝叶斯派
为什么说贝叶斯派是求积分呢?我们先来看贝叶斯定理:
P
(
θ
∣
X
)
=
P
(
X
∣
θ
)
P
(
θ
)
P
(
X
)
P(\theta|X)=\frac{P(X|\theta)P(\theta)}{P(X)}
P(θ∣X)=P(X)P(X∣θ)P(θ)
贝叶斯推断,要求得后验
P
(
θ
∣
X
)
P(\theta|X)
P(θ∣X) 。
贝叶斯决策。决策可以理解为就是做预测。即
X
X
X 为已知的
N
N
N 个样本数据。决策就是求:
P
(
x
~
∣
X
)
=
∫
θ
P
(
x
~
∣
X
)
d
θ
=
∫
θ
P
(
x
~
∣
θ
)
P
(
θ
∣
X
)
d
θ
P(\tilde{x}|X)=\int_\theta P(\tilde{x}|X)d\theta=\int_\theta P(\tilde{x}|\theta)P(\theta|X)d\theta
P(x~∣X)=∫θP(x~∣X)dθ=∫θP(x~∣θ)P(θ∣X)dθ
在通过贝叶斯推断求得后验
P
(
θ
∣
X
)
P(\theta|X)
P(θ∣X) 之后,就可以按照上式进行贝叶斯决策。而且上面这个式子也可以写成关于后验的期望的形式(期望就是求积分):
P
(
x
~
∣
X
)
=
E
θ
∣
X
[
P
(
x
~
∣
θ
)
]
P(\tilde{x}|X)=\mathbb{E}_{\theta|X}[P(\tilde{x}|\theta)]
P(x~∣X)=Eθ∣X[P(x~∣θ)]
贝叶斯派的关键就是求得后验
P
(
θ
∣
X
)
P(\theta|X)
P(θ∣X) ,即贝叶斯推断的过程。贝叶斯推断又可分为精确推断和近似推断:
- 精确推断
- 近似推断
- 确定性近似:变分推断(本文的主题)
- 随机近似:MCMC、MH、Gibbs
公式推导
符号含义: X X X 为观测数据, Z Z Z 为隐变量和参数。注意这里参数 θ \theta θ 也一同表示在 Z Z Z 中了。
再强调一下我们的目的:求后验 P ( Z ∣ X ) P(Z|X) P(Z∣X) 。
下面的前几步与 EM 算法导出的做法类似,详见 EM算法公式推导 ,区别只是把参数
θ
\theta
θ 合并到了
Z
Z
Z 中,步骤这里就不一一说明了。
log
P
(
X
)
=
log
P
(
X
,
Z
)
−
log
P
(
Z
∣
X
)
=
log
P
(
X
,
Z
)
q
(
Z
)
−
log
P
(
Z
∣
X
)
q
(
Z
)
=
∫
Z
q
(
Z
)
log
P
(
X
,
Z
)
q
(
Z
)
d
Z
−
∫
Z
q
(
Z
)
log
P
(
Z
∣
X
)
q
(
Z
)
d
Z
=
E
L
B
O
+
K
L
(
q
(
Z
)
∣
∣
P
(
Z
∣
X
)
)
=
L
(
q
)
+
K
L
(
q
(
Z
)
∣
∣
P
(
Z
∣
X
)
)
\begin{align} \log P(X)&=\log P(X,Z)-\log P(Z|X)\\ &=\log \frac{P(X,Z)}{q(Z)}-\log \frac{P(Z|X)}{q(Z)}\\ &=\int_Zq(Z)\log\frac{P(X,Z)}{q(Z)}dZ-\int_Zq(Z)\log \frac{P(Z|X)}{q(Z)}dZ\\ &=ELBO+KL(q(Z)||P(Z|X))\\ &=\mathcal{L}(q)+KL(q(Z)||P(Z|X)) \end{align}
logP(X)=logP(X,Z)−logP(Z∣X)=logq(Z)P(X,Z)−logq(Z)P(Z∣X)=∫Zq(Z)logq(Z)P(X,Z)dZ−∫Zq(Z)logq(Z)P(Z∣X)dZ=ELBO+KL(q(Z)∣∣P(Z∣X))=L(q)+KL(q(Z)∣∣P(Z∣X))
经过一系列变形,得到
E
B
L
O
+
K
L
EBLO+KL
EBLO+KL 的形式,这里我们将
E
L
B
O
ELBO
ELBO 记为
L
(
q
)
\mathcal{L}(q)
L(q) ,就是所谓的变分。
我们是要求的是后验
P
(
Z
∣
X
)
P(Z|X)
P(Z∣X) ,如果其与
q
(
Z
)
q(Z)
q(Z) 的 KL 散度接近0,那么就能用
q
(
Z
)
q(Z)
q(Z) 来对其进行近似。而等式左边
log
P
(
X
)
\log P(X)
logP(X) 与
Z
Z
Z 无关,因此
E
L
B
O
+
K
L
ELBO+KL
ELBO+KL 在
q
(
Z
)
q(Z)
q(Z) 变化时是个定值,因此,要让 KL 尽量小就转换为让 ELBO 尽量大,即有:
q
^
(
Z
)
=
arg
max
q
(
Z
)
L
(
q
)
→
q
(
Z
)
≈
P
(
Z
∣
X
)
\hat{q}(Z)=\arg\max_{q(Z)}\mathcal{L}(q)\ \ \ \ \rightarrow\ \ \ \ q(Z)\approx P(Z|X)
q^(Z)=argq(Z)maxL(q) → q(Z)≈P(Z∣X)
接下来,我们根据平均场理论,将
q
(
Z
)
q(Z)
q(Z) 划分为
M
M
M 个相互独立的份:
q
(
Z
)
=
∏
i
=
1
M
q
i
(
Z
i
)
q(Z)=\prod_{i=1}^Mq_i(Z_i)
q(Z)=i=1∏Mqi(Zi)
之后在求解的时候,我们会先固定
q
1
,
q
2
,
…
,
q
j
−
1
,
…
,
q
M
q_1,q_2,\dots,q_{j-1},\dots,q_M
q1,q2,…,qj−1,…,qM ,然后求解单个分量
q
j
q_j
qj ,最后将所有分量连乘起来,得到完整的
q
(
Z
)
q(Z)
q(Z) 。
首先先将
q
(
Z
)
q(Z)
q(Z) 代回到原式中:
L
(
q
)
=
∫
Z
q
(
Z
)
log
P
(
X
,
Z
)
d
Z
−
∫
Z
log
q
(
Z
)
d
Z
=
①
−
②
\mathcal{L}(q)=\int_Zq(Z)\log P(X,Z)dZ-\int_Z\log q(Z)dZ=①-②\\
L(q)=∫Zq(Z)logP(X,Z)dZ−∫Zlogq(Z)dZ=①−②
一项一项地来看:
①
=
∫
Z
q
(
Z
)
log
P
(
X
,
Z
)
d
Z
=
∫
Z
∏
i
=
1
M
q
i
(
Z
i
)
log
P
(
X
,
Z
)
d
Z
=
∫
Z
j
q
j
(
Z
j
)
∫
Z
i
(
i
≠
j
)
∏
i
≠
j
M
q
i
(
Z
i
)
log
P
(
X
,
Z
)
d
Z
i
(
i
≠
j
)
d
Z
j
=
∫
Z
j
q
j
(
Z
j
)
∫
Z
i
(
i
≠
j
)
log
P
(
X
,
Z
)
∏
i
≠
j
M
q
i
(
Z
i
)
d
Z
i
(
i
≠
j
)
d
Z
j
=
∫
Z
j
q
j
(
Z
j
)
⋅
E
∏
i
≠
j
M
q
i
(
Z
i
)
[
log
P
(
X
,
Z
)
]
d
Z
j
\begin{align} ①&=\int_Zq(Z)\log P(X,Z)dZ\\ &=\int_Z\prod_{i=1}^Mq_i(Z_i)\log P(X,Z)dZ\\ &=\int_{Z_j}q_j(Z_j)\int_{Z_i(i\ne j)}\prod_{i\ne j}^Mq_i(Z_i)\log P(X,Z)dZ_{i(i\ne j)}dZ_j\\ &=\int_{Z_j}q_j(Z_j)\int_{Z_i(i\ne j)}\log P(X,Z)\prod_{i\ne j}^Mq_i(Z_i)dZ_{i(i\ne j)}dZ_j\\ &=\int_{Z_j}q_j(Z_j)\cdot\mathbb{E}_{\prod_{i\ne j}^Mq_i(Z_i)}[\log P(X,Z)]dZ_j \end{align}
①=∫Zq(Z)logP(X,Z)dZ=∫Zi=1∏Mqi(Zi)logP(X,Z)dZ=∫Zjqj(Zj)∫Zi(i=j)i=j∏Mqi(Zi)logP(X,Z)dZi(i=j)dZj=∫Zjqj(Zj)∫Zi(i=j)logP(X,Z)i=j∏Mqi(Zi)dZi(i=j)dZj=∫Zjqj(Zj)⋅E∏i=jMqi(Zi)[logP(X,Z)]dZj
- 先将 q ( Z ) q(Z) q(Z) 进行拆分为 M M M 份;
- 然后将第 j j j 份拆出来;
- 其他份的积分写成期望的形式(见到积分,就考虑能写成期望)
然后看后面一项:
②
=
∫
Z
q
(
Z
)
log
q
(
Z
)
d
Z
=
∫
Z
∏
i
=
1
M
q
i
(
Z
i
)
log
∏
i
=
1
M
q
i
(
Z
i
)
d
Z
=
∫
Z
∏
i
=
1
M
q
i
(
Z
i
)
∑
i
=
1
M
log
q
i
(
Z
i
)
d
Z
=
∫
Z
∏
i
=
1
M
q
i
(
Z
i
)
[
log
q
1
(
Z
1
)
+
log
q
2
(
Z
2
)
+
⋯
+
log
q
M
(
Z
M
)
]
d
Z
\begin{align} ②&=\int_Zq(Z)\log q(Z)dZ\\ &=\int_Z\prod_{i=1}^Mq_i(Z_i)\log\prod_{i=1}^M q_i(Z_i)dZ\\ &=\int_Z\prod_{i=1}^Mq_i(Z_i)\sum_{i=1}^M\log q_i(Z_i)dZ\\ &=\int_Z\prod_{i=1}^Mq_i(Z_i)[\log q_1(Z_1)+\log q_2(Z_2)+\dots+\log q_M(Z_M)]dZ\\ \end{align}
②=∫Zq(Z)logq(Z)dZ=∫Zi=1∏Mqi(Zi)logi=1∏Mqi(Zi)dZ=∫Zi=1∏Mqi(Zi)i=1∑Mlogqi(Zi)dZ=∫Zi=1∏Mqi(Zi)[logq1(Z1)+logq2(Z2)+⋯+logqM(ZM)]dZ
- 写成 M M M 份;
- log 里面乘变外面加;
- 把连加号写开;
- 然后我们看其中一项(比如第一项):
∫ Z ∏ i = 1 M q i ( Z i ) ⋅ log q 1 ( Z 1 ) d Z = ∫ Z q 1 ( Z 1 ) q 2 ( Z 2 ) … q M ( Z M ) log q 1 ( Z 1 ) d Z = ∫ Z 1 Z 2 … Z M q 1 ( Z 1 ) q 2 ( Z 2 ) … q M ( Z M ) log q 1 ( Z 1 ) d Z 1 d Z 2 … d Z M = ∫ Z 1 q 1 ( Z 1 ) log q 1 ( Z 1 ) d Z 1 ∏ i = 2 M ∫ Z i q i ( Z i ) d Z i = ∫ Z 1 q 1 ( Z 1 ) log q 1 ( Z 1 ) d Z 1 \begin{align} \int_Z\prod_{i=1}^Mq_i(Z_i)\cdot\log q_1(Z_1)dZ&=\int_Zq_1(Z_1)q_2(Z_2)\dots q_M(Z_M)\log q_1(Z_1)dZ\\ &=\int_{Z_1Z_2\dots Z_M}q_1(Z_1)q_2(Z_2)\dots q_M(Z_M)\log q_1(Z_1)dZ_1dZ_2\dots dZ_M\\ &=\int_{Z_1}q_1(Z_1)\log q_1(Z_1)dZ_1\prod_{i=2}^M\int_{Z_i}q_i(Z_i)dZ_i\\ &=\int_{Z_1}q_1(Z_1)\log q_1(Z_1)dZ_1 \end{align} ∫Zi=1∏Mqi(Zi)⋅logq1(Z1)dZ=∫Zq1(Z1)q2(Z2)…qM(ZM)logq1(Z1)dZ=∫Z1Z2…ZMq1(Z1)q2(Z2)…qM(ZM)logq1(Z1)dZ1dZ2…dZM=∫Z1q1(Z1)logq1(Z1)dZ1i=2∏M∫Ziqi(Zi)dZi=∫Z1q1(Z1)logq1(Z1)dZ1
- 把 q 1 ( Z 1 ) q_1(Z_1) q1(Z1) 相关的移到一起;
- 剩下的积分全都是 1
② = ∑ i = 1 M ∫ Z i q i ( Z i ) log q i ( Z i ) d Z i = ∫ Z j q j ( Z j ) log q j ( Z j ) d Z j + C \begin{align} ②&=\sum_{i=1}^M\int_{Z_i}q_i(Z_i)\log q_i(Z_i)dZ_i\\ &=\int_{Z_j}q_j(Z_j)\log q_j(Z_j)dZ_j+C\\ \end{align} ②=i=1∑M∫Ziqi(Zi)logqi(Zi)dZi=∫Zjqj(Zj)logqj(Zj)dZj+C
- 有了 i = 1 i=1 i=1 时的表示,我们就把整个第二项写成连加的形式;
- 我们只关心第 j j j 项,其余的视作常数 C C C
这样处理完两项,有:
①
−
②
=
∫
Z
j
q
j
(
Z
j
)
⋅
E
∏
i
≠
j
M
q
i
(
Z
i
)
[
log
P
(
X
,
Z
)
]
d
Z
j
−
∫
Z
j
q
j
(
Z
j
)
log
q
j
(
Z
j
)
d
Z
j
+
C
=
∫
Z
j
q
j
(
Z
j
)
⋅
log
P
^
(
X
,
Z
j
)
d
Z
j
−
∫
Z
j
q
j
(
Z
j
)
log
q
j
(
Z
j
)
d
Z
j
+
C
=
∫
Z
j
q
j
(
Z
j
)
⋅
log
P
^
(
X
,
Z
j
)
q
j
(
Z
j
)
d
Z
j
=
−
K
L
(
P
^
(
X
,
Z
j
)
∣
∣
q
j
(
Z
j
)
)
≤
0
\begin{align} ①-②&=\int_{Z_j}q_j(Z_j)\cdot\mathbb{E}_{\prod_{i\ne j}^Mq_i(Z_i)}[\log P(X,Z)]dZ_j-\int_{Z_j}q_j(Z_j)\log q_j(Z_j)dZ_j+C\\ &=\int_{Z_j}q_j(Z_j)\cdot\log \hat{P}(X,Z_j) dZ_j-\int_{Z_j}q_j(Z_j)\log q_j(Z_j)dZ_j+C\\ &=\int_{Z_j}q_j(Z_j)\cdot\log\frac{ \hat{P}(X,Z_j)}{q_j(Z_j)}dZ_j\\ &=-KL(\hat{P}(X,Z_j)||q_j(Z_j))\le 0 \end{align}
①−②=∫Zjqj(Zj)⋅E∏i=jMqi(Zi)[logP(X,Z)]dZj−∫Zjqj(Zj)logqj(Zj)dZj+C=∫Zjqj(Zj)⋅logP^(X,Zj)dZj−∫Zjqj(Zj)logqj(Zj)dZj+C=∫Zjqj(Zj)⋅logqj(Zj)P^(X,Zj)dZj=−KL(P^(X,Zj)∣∣qj(Zj))≤0
- 将 ① 中的期望写成一个函数的形式: P ^ ( X , Z j ) \hat{P}(X,Z_j) P^(X,Zj) ;
- 最后就是一个负的 KL 散度,当 P ^ ( X , Z j ) = q j ( Z j ) \hat{P}(X,Z_j)=q_j(Z_j) P^(X,Zj)=qj(Zj) 时取到等号