#! https://zhuanlan.zhihu.com/p/402301009
A Detailed Explanation of the EM Algorithm
1.1 Overview of the EM Algorithm
We start from maximum likelihood estimation. By the maximum-likelihood principle, the parameter is obtained as
$$\theta_{MLE}=\underset{\theta}{argmax}\,P(x\mid\theta)$$
The EM algorithm iterates according to
$$\begin{aligned}\theta^{(t+1)}&=\underset{\theta}{argmax}\int_z\log P(x, z\mid\theta)\cdot P(z\mid x, \theta^{(t)})\,dz\\ &=\underset{\theta}{argmax}\,E_{z\mid x, \theta^{(t)}}[\log P(x, z\mid\theta)]\end{aligned}$$
1.2 Proof of Convergence
Having stated the iteration above, we now prove that it converges, i.e. that
$$\log P(x\mid\theta^{(t)}) \leq \log P(x\mid\theta^{(t+1)})$$
Proof: since
$$\log p(x \mid \theta)=\log p(x, z \mid \theta)-\log p(z \mid x, \theta)$$
we take the expectation of both sides under $p(z \mid x, \theta^{(t)})$ (the left-hand side does not depend on $z$, so it is unchanged):
$$\begin{aligned}\int_{z} p(z \mid x, \theta^{(t)}) \log p(x \mid \theta)\, dz&=\int_{z} p(z \mid x, \theta^{(t)}) \log p(x, z\mid\theta)\, dz - \int_z p(z \mid x, \theta^{(t)})\log p(z\mid x, \theta)\, dz \\ \log p(x \mid \theta)&=Q(\theta, \theta^{(t)})+H(\theta, \theta^{(t)})\end{aligned}$$
where
$$Q(\theta, \theta^{(t)})=\int_{z} p(z \mid x, \theta^{(t)}) \log p(x, z\mid\theta)\, dz, \qquad H(\theta, \theta^{(t)})=-\int_z p(z \mid x, \theta^{(t)})\log p(z\mid x, \theta)\, dz$$
Since $\theta^{(t+1)}=\underset{\theta}{argmax}\,Q(\theta, \theta^{(t)})$ by definition, we immediately have
$$Q(\theta^{(t)}, \theta^{(t)}) \le Q(\theta^{(t+1)}, \theta^{(t)})$$
It remains to show that
$$H(\theta^{(t)}, \theta^{(t)}) \le H(\theta^{(t+1)}, \theta^{(t)})$$
$$\begin{aligned} H(\theta^{(t+1)}, \theta^{(t)})-H(\theta^{(t)}, \theta^{(t)})&=\int_{z} p( z \mid x, \theta^{(t)}) \log p(z\mid x, \theta^{(t)})\,dz-\int_{z} p( z \mid x, \theta^{(t)}) \log p(z\mid x, \theta^{(t+1)})\,dz \\ &=\int_{z} p( z \mid x, \theta^{(t)})\log \frac{p(z\mid x, \theta^{(t)})}{p(z\mid x, \theta^{(t+1)})}\,dz \\ &=KL(p( z \mid x, \theta^{(t)}) \| p( z \mid x, \theta^{(t+1)})) \\ &\ge 0 \end{aligned}$$
Adding the two inequalities yields
$$\log p(x\mid\theta^{(t)})=Q(\theta^{(t)}, \theta^{(t)})+H(\theta^{(t)}, \theta^{(t)}) \le Q(\theta^{(t+1)}, \theta^{(t)})+H(\theta^{(t+1)}, \theta^{(t)})=\log p(x\mid\theta^{(t+1)})$$
so the claim holds and the iteration is valid.
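The monotonicity just proved can also be observed numerically. Below is a minimal sketch with a toy model of our own (not from the article): $z \sim Bernoulli(\theta)$, $x\mid z \sim Bernoulli(0.2+0.6z)$, with the M-step done by grid search over $\theta$. The log-likelihood of the iterates never decreases.

```python
import math

# Toy model (illustrative assumption, not from the article):
#   z ~ Bernoulli(theta),  x | z ~ Bernoulli(0.2 + 0.6*z)
# theta is the only unknown; the steps below follow the iteration
#   theta^{(t+1)} = argmax_theta E_{z|x,theta^{(t)}}[log P(x, z | theta)].

def joint(x, z, theta):
    """p(x, z | theta)."""
    pz = theta if z == 1 else 1.0 - theta
    px1 = 0.2 + 0.6 * z                 # p(x = 1 | z)
    return pz * (px1 if x == 1 else 1.0 - px1)

def log_lik(xs, theta):
    """log p(x | theta) summed over the data (z marginalised out)."""
    return sum(math.log(joint(x, 0, theta) + joint(x, 1, theta)) for x in xs)

def em_step(xs, theta_t, grid):
    # E-step: posterior weights p(z | x, theta^{(t)})
    posts = []
    for x in xs:
        marg = joint(x, 0, theta_t) + joint(x, 1, theta_t)
        posts.append((joint(x, 0, theta_t) / marg, joint(x, 1, theta_t) / marg))
    # M-step: maximise Q(theta, theta^{(t)}) by grid search
    def Q(theta):
        return sum(w0 * math.log(joint(x, 0, theta)) + w1 * math.log(joint(x, 1, theta))
                   for x, (w0, w1) in zip(xs, posts))
    return max(grid, key=Q)

xs = [1, 1, 0, 1, 0, 1, 1, 1, 0, 1]
grid = [i / 200 for i in range(1, 200)]
theta, liks = 0.1, []
for _ in range(15):
    liks.append(log_lik(xs, theta))
    theta = em_step(xs, theta, grid)

# monotonicity, exactly as guaranteed by the proof above
assert all(b >= a - 1e-12 for a, b in zip(liks, liks[1:]))
```

Because the current $\theta^{(t)}$ always lies on the grid, the grid-search M-step can only improve $Q$, so the proof applies verbatim.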
2 The EM Algorithm
Below we state the concrete steps of the EM algorithm.
Suppose $X$ is the observed data, $Z$ is the latent variable, and $(X, Z)$ is the complete data.
$E\ step$: using the current parameter $\theta^{(t)}$, compute the expectation
$$Q(\theta, \theta^{(t)})=E_{z\mid x, \theta^{(t)}}[\log P(x, z\mid\theta)]$$
$M\ step$: maximize it over $\theta$:
$$\theta^{(t+1)}=\underset{\theta}{argmax}\,Q(\theta, \theta^{(t)})$$
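As a concrete instance of these two steps, here is a minimal EM sketch for a two-component 1-D Gaussian mixture with shared variance. The model, data, and all names are our own illustrative assumptions, not from the article: the E-step computes the posterior responsibilities $p(z\mid x, \theta^{(t)})$, and the M-step maximizes the expected complete-data log-likelihood, which has a closed form here.

```python
import math
import random

# EM for a two-component 1-D Gaussian mixture with shared variance
# (an assumed toy example).  theta = (pi, mu0, mu1, var).

def normal_pdf(x, mu, var):
    return math.exp(-(x - mu) ** 2 / (2 * var)) / math.sqrt(2 * math.pi * var)

def em_gmm(xs, n_iter=50):
    pi, mu0, mu1, var = 0.5, min(xs), max(xs), 1.0   # crude initialisation
    for _ in range(n_iter):
        # E-step: responsibilities r_i = p(z_i = 1 | x_i, current params)
        r = []
        for x in xs:
            p0 = (1 - pi) * normal_pdf(x, mu0, var)
            p1 = pi * normal_pdf(x, mu1, var)
            r.append(p1 / (p0 + p1))
        # M-step: weighted maximum-likelihood updates (closed-form argmax
        # of the expected complete-data log-likelihood)
        n1 = sum(r)
        n0 = len(xs) - n1
        pi = n1 / len(xs)
        mu1 = sum(ri * x for ri, x in zip(r, xs)) / n1
        mu0 = sum((1 - ri) * x for ri, x in zip(r, xs)) / n0
        var = sum(ri * (x - mu1) ** 2 + (1 - ri) * (x - mu0) ** 2
                  for ri, x in zip(r, xs)) / len(xs)
    return pi, mu0, mu1, var

random.seed(0)
xs = [random.gauss(-2, 1) for _ in range(200)] + [random.gauss(3, 1) for _ in range(200)]
pi, mu0, mu1, var = em_gmm(xs)
print(round(mu0, 1), round(mu1, 1))   # the means should land near -2 and 3
```

The closed-form M-step is exactly the $\underset{\theta}{argmax}$ above, carried out analytically instead of numerically.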
3 Derivation of the EM Algorithm
For the log-likelihood we have, with $q(z)$ an arbitrary distribution over $z$,
$$\begin{aligned}\log P(x \mid \theta)&=\log P(x, z \mid \theta)-\log P(z \mid x, \theta) \\ &=\log \frac{P(x, z \mid \theta)}{q(z)}-\log \frac{P(z \mid x, \theta)}{q(z)}\end{aligned}$$
Again take the expectation of both sides, this time under $q(z)$:
$$\begin{aligned} \int_z q(z)\log P(x \mid \theta)\,dz &= \int_z q(z)\log \frac{P(x, z \mid \theta)}{q(z)}\,dz - \int_z q(z)\log \frac{P(z \mid x, \theta)}{q(z)}\,dz \\ \log P(x \mid \theta) &= ELBO + KL(q(z) \| P(z\mid x, \theta)) \end{aligned}$$
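This decomposition can be verified numerically. A minimal sketch with a discrete latent variable $z\in\{0,1\}$; the joint values and the choice of $q$ below are arbitrary assumptions for illustration:

```python
import math

# Check  log p(x|theta) = ELBO + KL(q || p(z|x,theta))  for a toy
# discrete latent z in {0, 1}.  The joint values are assumed examples.

p_joint = {0: 0.12, 1: 0.28}           # p(x, z | theta) for the observed x
p_x = sum(p_joint.values())            # p(x | theta), marginalising out z
posterior = {z: p_joint[z] / p_x for z in p_joint}

q = {0: 0.7, 1: 0.3}                   # any q(z) with full support

elbo = sum(q[z] * math.log(p_joint[z] / q[z]) for z in q)
kl = sum(q[z] * math.log(q[z] / posterior[z]) for z in q)

assert abs(math.log(p_x) - (elbo + kl)) < 1e-12   # the decomposition holds
assert kl >= 0                                    # hence log p(x|theta) >= ELBO
```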
The idea of the EM algorithm is to increase the $ELBO$ through the iterations, thereby raising the log-likelihood. That is,
$$\begin{aligned} \hat{\theta}&=\underset{\theta}{argmax}\,ELBO \\ &=\underset{\theta}{argmax}\int_z q(z)\log \frac{P(x, z \mid \theta)}{q(z)}\,dz \\ &=\underset{\theta}{argmax}\int_z p(z \mid x, \theta^{(t)}) \log \frac{P(x, z \mid \theta)}{p(z \mid x, \theta^{(t)})}\,dz \\ &=\underset{\theta}{argmax}\int_z p(z \mid x, \theta^{(t)}) \log P(x, z \mid \theta)\,dz \end{aligned}$$
Here the third line fixes $q(z)=p(z \mid x, \theta^{(t)})$, and the last line drops the denominator because it does not depend on $\theta$.
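The last step above is worth a numerical check: since the fixed denominator $p(z\mid x, \theta^{(t)})$ contributes only a $\theta$-independent constant, maximizing the $ELBO$ over $\theta$ and maximizing $E_{z\mid x,\theta^{(t)}}[\log P(x,z\mid\theta)]$ give the same answer. A toy check with an assumed joint:

```python
import math

# ELBO(theta) and Q(theta) = E_{z|x,theta^(t)}[log P(x,z|theta)] differ
# only by a theta-independent constant, so their argmax coincides.
# Toy model (assumed): z ~ Bernoulli(theta), p(x=1 | z) fixed, x = 1 observed.

def joint(z, theta):                    # p(x=1, z | theta)
    return (theta if z else 1 - theta) * (0.9 if z else 0.2)

theta_t = 0.4
marg = joint(0, theta_t) + joint(1, theta_t)
q = {z: joint(z, theta_t) / marg for z in (0, 1)}   # posterior at theta^(t)

grid = [i / 1000 for i in range(1, 1000)]
elbo = lambda th: sum(q[z] * math.log(joint(z, th) / q[z]) for z in q)
Q = lambda th: sum(q[z] * math.log(joint(z, th)) for z in q)

# the two objectives pick the same grid point
assert max(grid, key=elbo) == max(grid, key=Q)
```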
From the maximum-likelihood perspective we have thus derived the iteration of the $EM$ algorithm: it simply keeps increasing the $ELBO$.
4 Another Angle on the Derivation of EM
Here we present another way to arrive at the $EM$ algorithm, via Jensen's inequality (for background, see the Zhihu answer on Jensen不等式).
$$\begin{aligned} \log P(x\mid\theta)&=\log \int_z P(x, z\mid\theta)\,dz \\ &=\log \int_z q(z) \cdot \frac{P(x, z\mid\theta)}{q(z)}\,dz \\ &= \log E_{q(z)}\Big[\frac{P(x, z\mid\theta)}{q(z)}\Big] \\ &\ge E_{q(z)}\Big[\log \frac{P(x, z\mid\theta)}{q(z)}\Big] \\ &= ELBO \end{aligned}$$
By Jensen's inequality, equality holds only when $\frac{P(x, z\mid\theta)}{q(z)}$ is a constant.
In that case
$$\begin{aligned} \frac{P(x, z\mid\theta)}{q(z)}&=C \\ q(z)&=\frac{P(x, z\mid\theta)}{C} \\ \int_z q(z)\,dz &= \int_z \frac{P(x, z\mid\theta)}{C}\, dz \\ 1 &= \frac{1}{C} P(x\mid\theta) \\ P(x\mid\theta) &= C \end{aligned}$$
Therefore
$$q(z) = \frac{P(x, z\mid\theta)}{P(x\mid\theta)} = P(z\mid x, \theta)$$
This shows that the optimal $q(z)$ is exactly the posterior of $z$, consistent with the result of the previous derivation.
5 Remarks on the EM Algorithm
- The EM algorithm solves learning problems for probabilistic generative models with latent variables
- The narrow EM algorithm can be extended to the generalized EM algorithm
From the derivation above we have the conclusion
$$\log P(x\mid\theta) = ELBO + KL(q\|p)$$
where
$$\left\{\begin{array}{l} ELBO=E_{q(z)} \big[ \log \frac{P(x, z \mid \theta)}{q(z)}\big] \\ KL(q \| p)=\int_{z} q(z) \log \frac{q(z)}{p(z \mid x, \theta)}\, dz \end{array}\right.$$
Therefore
$$\log P(x\mid\theta) \ge ELBO=\mathcal{L}(q, \theta)$$
In the narrow EM algorithm, $q(z)=P(z\mid x, \theta^{(t)})$. But sometimes $P(z\mid x, \theta^{(t)})$ is itself hard to obtain (it can then be approximated with inference methods such as variational inference or MCMC sampling), so we cannot simply set $q(z)=P(z\mid x, \theta^{(t)})$; this leads to the generalized EM algorithm:
Formally, with $\hat{\theta}$ fixed, $\hat{q} = \underset{q}{argmin}\,KL(q\|p)=\underset{q}{argmax}\,\mathcal{L}(q, \hat{\theta})$; with $\hat{q}$ fixed, $\hat{\theta}=\underset{\theta}{argmax}\,\mathcal{L}(\hat{q}, \theta)$.
The generalized EM algorithm:
$$\left\{\begin{array}{l} E\text{-}step: q^{(t+1)}=\underset{q}{argmax}\,\mathcal{L}(q, \theta^{(t)}) \\ M\text{-}step: \theta^{(t+1)}=\underset{\theta}{argmax}\,\mathcal{L}(q^{(t+1)}, \theta) \end{array}\right.$$
Observe the form of the $ELBO$ in the generalized EM algorithm:
$$\begin{aligned} \mathcal{L}(q, \theta) &= E_{q}[\log P(x, z\mid\theta) - \log q(z)] \\ &=E_{q}[\log P(x, z\mid\theta)] - E_{q}[\log q(z)] \\ &=E_{q}[\log P(x, z\mid\theta)] + H(q) \end{aligned}$$
where $H(q)=-\int_z q(z)\log q(z)\,dz$ is the entropy of the distribution $q$.
In the narrow EM algorithm this second term does not appear, because there $q$ is fixed to $p(z\mid x, \theta^{(t)})$: $H(q)$ is then a constant independent of $\theta$ and can be dropped from the maximization over $\theta$.
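The identity $\mathcal{L}(q,\theta)=E_q[\log P(x,z\mid\theta)]+H(q)$ is easy to confirm numerically. A minimal sketch with assumed joint values and an arbitrary $q$:

```python
import math

# Check  ELBO = E_q[log P(x,z|theta)] + H(q)  against the defining form
# E_q[log (P(x,z|theta) / q(z))].  All values are assumed toy examples.

p_joint = {0: 0.15, 1: 0.25}           # p(x, z | theta), z in {0, 1}
q = {0: 0.4, 1: 0.6}                   # an arbitrary q(z)

elbo = sum(q[z] * math.log(p_joint[z] / q[z]) for z in q)
expected_complete = sum(q[z] * math.log(p_joint[z]) for z in q)
entropy = -sum(q[z] * math.log(q[z]) for z in q)

assert abs(elbo - (expected_complete + entropy)) < 1e-12
```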
6 Variants of the EM Algorithm
As mentioned earlier, the narrow EM algorithm takes $q(z)=P(z\mid x, \theta^{(t)})$. When $P(z\mid x, \theta^{(t)})$ is hard to obtain, approximate inference methods (variational inference, MCMC sampling) can be used in its place, which gives the EM variants VEM and MCEM respectively.
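A sketch of the idea behind MCEM; everything below is an illustrative assumption, not from the article. When the E-step expectation has no closed form, one draws samples $z_s \sim p(z\mid x, \theta^{(t)})$ and replaces the integral with a sample average. For a binary $z$ the exact expectation is available, so we can compare the two:

```python
import math
import random

# Monte-Carlo E-step behind MCEM: approximate
#   Q(theta, theta_t) = E_{z|x,theta_t}[log P(x, z | theta)]
# with samples z_s ~ p(z | x, theta_t).  Toy model (assumed):
# z ~ Bernoulli(theta), p(x=1 | z) fixed, x = 1 observed.

def joint(z, theta):                     # p(x=1, z | theta)
    return (theta if z else 1 - theta) * (0.9 if z else 0.2)

def mc_q(theta, theta_t, n_samples=50_000, seed=0):
    """Monte-Carlo estimate of Q(theta, theta_t) for binary z."""
    marg = joint(0, theta_t) + joint(1, theta_t)
    post1 = joint(1, theta_t) / marg     # here the posterior is tractable;
    rng = random.Random(seed)            # we sample anyway to mimic MCEM
    zs = [1 if rng.random() < post1 else 0 for _ in range(n_samples)]
    return sum(math.log(joint(z, theta)) for z in zs) / n_samples

def exact_q(theta, theta_t):
    marg = joint(0, theta_t) + joint(1, theta_t)
    return sum(joint(z, theta_t) / marg * math.log(joint(z, theta))
               for z in (0, 1))

# the Monte-Carlo estimate should be close to the exact expectation
assert abs(mc_q(0.5, 0.4) - exact_q(0.5, 0.4)) < 0.05
```

In real MCEM the posterior samples would come from an MCMC chain rather than exact sampling; the M-step then maximizes the sample-average objective over $\theta$.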