EM 算法
1、问题描述
极大似然法是最常用的参数估计方法之一。设观测变量为
x
x
x,模型参数为
θ
\theta
θ ,则极大似然法通过最大化似然函数
p
(
x
∣
θ
)
p(x|\theta)
p(x∣θ) 或对数似然
log
p
(
x
∣
θ
)
\log p(x|\theta)
logp(x∣θ) 来求解最优的
θ
\theta
θ. 然而在一些问题中,观测变量
x
x
x 依赖于隐变量
z
z
z. 此时根据全概率公式有:
p
(
x
∣
θ
)
=
∑
z
p
(
x
,
z
∣
θ
)
o
r
p
(
x
∣
θ
)
=
∫
z
p
(
x
,
z
∣
θ
)
d
z
p(x|\theta)=\sum_z p(x,z|\theta) \quad or \quad p(x|\theta)=\int_z p(x,z|\theta)dz
p(x∣θ)=z∑p(x,z∣θ)orp(x∣θ)=∫zp(x,z∣θ)dz
对数似然为:
L
(
θ
)
=
log
p
(
x
∣
θ
)
=
log
∑
z
p
(
x
,
z
∣
θ
)
L(\theta)=\log p(x|\theta)=\log \sum_z p(x,z|\theta)
L(θ)=logp(x∣θ)=logz∑p(x,z∣θ)
如果仍然使用极大似然法,我们会发现
L
(
θ
)
L(\theta)
L(θ) 的导数将变得非常复杂,要优化的参数之间无法分离,导致无法写出封闭形式的解。这时就需要用到 EM 算法了。
2、理论推导
首先介绍 Jensen 不等式:
若
f
f
f 是凸函数,则有:
E
(
f
(
X
)
)
≥
f
(
E
(
X
)
)
E(f(X))\ge f(E(X))
E(f(X))≥f(E(X))
当随机变量
X
X
X 为常数时等号成立。
若
f
f
f 是凹函数,则结论相反。
回到上面的问题,我们为隐变量
z
z
z 引入一个概率分布
q
(
z
)
q(z)
q(z) ,则有:
L
(
θ
)
=
log
∑
z
p
(
x
,
z
∣
θ
)
=
log
(
∑
z
q
(
z
)
p
(
x
,
z
∣
θ
)
q
(
z
)
)
≥
∑
z
q
(
z
)
log
p
(
x
,
z
∣
θ
)
q
(
z
)
=
J
(
θ
,
q
)
L(\theta)=\log \sum_z p(x,z|\theta)=\log \left(\sum_z q(z)\frac{p(x,z|\theta)}{q(z)}\right)\ge \sum_z q(z)\log\frac{p(x,z|\theta)}{q(z)} =J(\theta,q)
L(θ)=logz∑p(x,z∣θ)=log(z∑q(z)q(z)p(x,z∣θ))≥z∑q(z)logq(z)p(x,z∣θ)=J(θ,q)
(不等号使用了 Jensen不等式。函数
f
f
f 是 log 函数,凹函数)
当
q
(
z
)
=
p
(
z
∣
x
,
θ
)
q(z)=p(z|x,\theta)
q(z)=p(z∣x,θ) 时,等号成立,因此
J
(
θ
,
q
)
J(\theta,q)
J(θ,q) 是
L
(
θ
)
L(\theta)
L(θ) 紧的下界,所以可以将最大化
L
(
θ
)
L(\theta)
L(θ) 的问题变为最大化
J
(
θ
,
q
)
J(\theta,q)
J(θ,q):
max
θ
L
(
θ
)
⇔
max
θ
,
q
J
(
θ
,
q
)
\max_\theta L(\theta)\Leftrightarrow\max_{\theta,q} J(\theta,q)
θmaxL(θ)⇔θ,qmaxJ(θ,q)
不过,同时优化
θ
\theta
θ 和
q
q
q 依旧非常困难,因此 EM 算法采用交替迭代的方式优化:
-
E-step 固定 $\theta $ 优化 q q q:
q t + 1 = arg max q J ( θ t , q ) q^{t+1}=\argmax_q J(\theta^t,q) qt+1=qargmaxJ(θt,q)
由上面取等号条件可知:$q{t+1}=p(z|x,\thetat) $ -
M-step 固定 q q q 优化 $\theta $:
θ t + 1 = arg max θ J ( θ , q t + 1 ) \theta^{t+1}=\argmax_\theta J(\theta,q^{t+1}) θt+1=θargmaxJ(θ,qt+1)
上述优化目标可以拆出一个与 θ \theta θ 无关的常数项:J ( θ , q t + 1 ) = ∑ z q t + 1 ( z ) log p ( x , z ∣ θ ) − ∑ z q t + 1 ( z ) log q t + 1 ( z ) = ∑ z p ( z ∣ x , θ t ) log p ( x , z ∣ θ ) − ∑ z q t + 1 ( z ) log q t + 1 ( z ) \begin{align*} J(\theta,q^{t+1})&=\sum_z q^{t+1}(z)\log p(x,z|\theta)-\sum_z q^{t+1}(z)\log q^{t+1}(z) \\ &=\sum_z p(z|x,\theta^t)\log p(x,z|\theta)-\sum_z q^{t+1}(z)\log q^{t+1}(z) \end{align*} J(θ,qt+1)=z∑qt+1(z)logp(x,z∣θ)−z∑qt+1(z)logqt+1(z)=z∑p(z∣x,θt)logp(x,z∣θ)−z∑qt+1(z)logqt+1(z)
令 Q ( θ , θ t ) = ∑ z p ( z ∣ x , θ t ) log p ( x , z ∣ θ ) Q(\theta,\theta^t)=\sum_z p(z|x,\theta^t)\log p(x,z|\theta) Q(θ,θt)=∑zp(z∣x,θt)logp(x,z∣θ),因为第二项为常数,因此只优化第一项即可:
θ t + 1 = arg max θ Q ( θ , θ t ) \theta^{t+1}=\argmax_\theta Q(\theta,\theta^t) θt+1=θargmaxQ(θ,θt)
经过一轮 E-step 和 M-step,有:
L
(
θ
t
+
1
)
≥
J
(
q
t
+
1
,
θ
t
+
1
)
≥
J
(
q
t
+
1
,
θ
t
)
=
L
(
θ
t
)
L(\theta^{t+1})\ge J(q^{t+1},\theta^{t+1})\ge J(q^{t+1},\theta^t)=L(\theta^t)
L(θt+1)≥J(qt+1,θt+1)≥J(qt+1,θt)=L(θt)
可知
L
(
θ
)
L(\theta)
L(θ) 确实得到了优化。
3、算法步骤
综上所述,EM算法的步骤如下:
- 随机初始化 θ 0 \theta^0 θ0
- E-step:给定
θ
t
\theta^t
θt,求隐变量的后验分布:
q t + 1 = p ( z ∣ x , θ t ) q^{t+1}=p(z|x,\theta^t) qt+1=p(z∣x,θt) - M-step:
优化 Q ( θ , θ t ) = ∑ z p ( z ∣ x , θ t ) log p ( x , z ∣ θ ) Q(\theta,\theta^t)=\sum_z p(z|x,\theta^t)\log p(x,z|\theta) Q(θ,θt)=z∑p(z∣x,θt)logp(x,z∣θ)
得到: θ t + 1 = arg max θ Q ( θ , θ t ) \theta^{t+1}=\argmax_\theta Q(\theta,\theta^t) θt+1=θargmaxQ(θ,θt) - 迭代执行 2、3 步直至收敛。
4、另一种分析
可以从另一种角度,绕过 Jensen 不等式进行推导 ,如下:
L
(
θ
)
=
log
p
(
x
∣
θ
)
=
∫
z
q
(
z
)
log
p
(
x
∣
θ
)
d
z
=
∫
z
q
(
z
)
log
p
(
x
∣
z
,
θ
)
p
(
z
)
p
(
z
∣
x
,
θ
)
q
(
z
)
q
(
z
)
d
z
=
∫
z
q
(
z
)
log
p
(
x
∣
z
,
θ
)
d
z
−
∫
z
q
(
z
)
log
q
(
z
)
p
(
z
)
d
z
+
∫
z
q
(
z
)
log
q
(
z
)
p
(
z
∣
x
,
θ
)
d
z
=
∫
z
q
(
z
)
log
p
(
x
∣
z
,
θ
)
p
(
z
)
q
(
z
)
d
z
+
K
L
(
q
(
z
)
∣
∣
p
(
z
∣
x
,
θ
)
)
=
∫
z
q
(
z
)
log
p
(
x
,
z
∣
θ
)
d
z
⏟
J
(
θ
,
q
)
+
K
L
(
q
(
z
)
∣
∣
p
(
z
∣
x
,
θ
)
)
⏟
≥
0
\begin{align*} L(\theta)&=\log p(x|\theta)\\ &=\int_z q(z)\log p(x|\theta)dz\\ &=\int_z q(z)\log \frac{p(x|z,\theta)p(z)}{p(z|x,\theta)}\frac{q(z)}{q(z)}dz\\ &=\int_z q(z) \log p(x|z,\theta)dz-\int_z q(z)\log \frac{q(z)}{p(z)}dz+\int_z q(z)\log\frac{q(z)}{p(z|x,\theta)}dz\\ &=\int_z q(z)\log \frac{p(x|z,\theta)p(z)}{q(z)}dz + KL(q(z)||p(z|x,\theta))\\ &=\underbrace{\int_z q(z)\log p(x,z|\theta)dz}_{J(\theta,q)}+\underbrace{KL(q(z)||p(z|x,\theta))}_{\ge 0} \end{align*}
L(θ)=logp(x∣θ)=∫zq(z)logp(x∣θ)dz=∫zq(z)logp(z∣x,θ)p(x∣z,θ)p(z)q(z)q(z)dz=∫zq(z)logp(x∣z,θ)dz−∫zq(z)logp(z)q(z)dz+∫zq(z)logp(z∣x,θ)q(z)dz=∫zq(z)logq(z)p(x∣z,θ)p(z)dz+KL(q(z)∣∣p(z∣x,θ))=J(θ,q)
∫zq(z)logp(x,z∣θ)dz+≥0
KL(q(z)∣∣p(z∣x,θ))
从而得到 J ( θ , q ) ≤ L ( θ ) J(\theta,q)\le L(\theta) J(θ,q)≤L(θ).
这也正是 ELBO 的推导过程。
ELBO 一般形式:
log p ( x ) ≥ E q ( z ) [ log p ( x , z ) q ( z ) ] \log p(x)\ge E_{q(z)}\left[\log \frac{p(x,z)}{q(z)}\right] logp(x)≥Eq(z)[logq(z)p(x,z)]