白板推导系列Pytorch-期望最大(EM)算法
EM算法介绍直接看这篇博客-如何通俗理解EM算法,讲的非常好,里面也有算法的公式推导。当然白板推导的视频里面公式推导已经讲的很清楚了,就是缺少应用实例。这篇博客用三个很通俗的例子引入了极大似然估计和EM算法。美中不足的是并没有详细说明极大似然估计并不是一定陷入鸡生蛋蛋生鸡的循环而没有办法处理隐变量问题,而是由于计算复杂从而摒弃了这个方法。
当我们能知道z的分布的时候,其实也是可以用极大似然估计表示的
但是,很多时候,我们很难获得Z的分布,除非我们事先对Z已经很有了解,比如我们如果能够确定Z是一个伯努利分布(比如三硬币模型),那么对Z的分布估计问题就转化成了一个P参数的估计问题
虽然但是,即便知道z的分布,使用极大似然估计也不一定能求。
下面我们一起看看三硬币模型的例子
三硬币模型
有三枚硬币(ABC)正面向上的概率分别为 π , p , q \pi,p,q π,p,q。进行如下试验——先掷A,如果A正面向上则掷B,如果A反面向上则掷C。如此独立地重复做n次试验;记录B和C的结果,正面向上记为1,观测结果为: 1 , 1 , 0 , 0 , . . . 1,1,0,0,... 1,1,0,0,...
若只能观测到结果,不能观测到掷硬币过程,即每一次的观测结果(1或0)由B或C中的哪枚硬币掷出的是未知的。如此情况下估计三枚硬币正面向上的概率 π , p , q \pi,p,q π,p,q。观测数据表示为 Y = ( Y 1 , Y 2 , . . . , Y n ) T Y = (Y_1,Y_2,...,Y_n)^T Y=(Y1,Y2,...,Yn)T,未观测数据表示为 Z = ( Z 1 , Z 2 , . . . , Z n ) Z=(Z_1,Z_2,...,Z_n) Z=(Z1,Z2,...,Zn)。
一次试验
P
(
y
∣
θ
)
=
π
p
y
(
1
−
p
)
(
1
−
y
)
+
(
1
−
π
)
q
y
(
1
−
q
)
(
1
−
y
)
P(y \mid \theta)=\pi p^{y}(1-p)^{(1-y)}+(1-\pi) q^{y}(1-q)^{(1-y)}
P(y∣θ)=πpy(1−p)(1−y)+(1−π)qy(1−q)(1−y)
n次试验
P ( Y ∣ θ ) = ∑ Z P ( Z ∣ θ ) P ( Y ∣ Z , θ ) = ∏ j = 1 n [ π p y j ( 1 − p ) ( 1 − y j ) + ( 1 − π ) q y j ( 1 − q ) ( 1 − y j ) ] \begin{aligned} P(Y \mid \theta) &=\sum_{Z} P(Z \mid \theta) P(Y \mid Z, \theta) \\ &=\prod_{j=1}^{n}\left[\pi p^{y_{j}}(1-p)^{\left(1-y_{j}\right)}+(1-\pi) q^{y_{j}}(1-q)^{\left(1-y_{j}\right)}\right] \end{aligned} P(Y∣θ)=Z∑P(Z∣θ)P(Y∣Z,θ)=j=1∏n[πpyj(1−p)(1−yj)+(1−π)qyj(1−q)(1−yj)]
极大似然估计
从极大似然的角度,我们显然是要找到最合适的
π
,
p
,
q
\pi,p,q
π,p,q使得
P
(
Y
∣
θ
)
P(Y \mid \theta)
P(Y∣θ)最大。如下
θ
=
a
r
g
m
a
x
θ
l
o
g
P
(
Y
∣
θ
)
=
a
r
g
m
a
x
θ
∑
j
=
1
n
l
o
g
[
π
p
y
j
(
1
−
p
)
(
1
−
y
j
)
+
(
1
−
π
)
q
y
j
(
1
−
q
)
(
1
−
y
j
)
]
\begin{aligned} \theta &= \underset{\theta}{argmax}\ log\ P(Y|\theta) \\ &= \underset{\theta}{argmax}\sum_{j=1}^{n}log\left[\pi p^{y_{j}}(1-p)^{\left(1-y_{j}\right)}+(1-\pi) q^{y_{j}}(1-q)^{\left(1-y_{j}\right)}\right] \end{aligned}
θ=θargmax log P(Y∣θ)=θargmaxj=1∑nlog[πpyj(1−p)(1−yj)+(1−π)qyj(1−q)(1−yj)]
然后如果我们对
π
,
p
,
q
\pi,p,q
π,p,q求偏导
∂
L
∂
π
=
∑
j
=
1
n
p
y
j
(
1
−
p
)
(
1
−
y
j
)
−
q
y
j
(
1
−
q
)
(
1
−
y
j
)
π
p
y
j
(
1
−
p
)
(
1
−
y
j
)
+
(
1
−
π
)
q
y
j
(
1
−
q
)
(
1
−
y
j
)
=
0
\frac{\partial L}{\partial \pi} = \sum_{j=1}^{n}\frac{p^{y_{j}}(1-p)^{\left(1-y_{j}\right)}-q^{y_{j}}(1-q)^{\left(1-y_{j}\right)}}{\pi p^{y_{j}}(1-p)^{\left(1-y_{j}\right)}+(1-\pi) q^{y_{j}}(1-q)^{\left(1-y_{j}\right)}} = 0
∂π∂L=j=1∑nπpyj(1−p)(1−yj)+(1−π)qyj(1−q)(1−yj)pyj(1−p)(1−yj)−qyj(1−q)(1−yj)=0
其它都不用求了,就看上面这个式子,你说这要怎么求?我是没有办法
EM求解
以下内容参考EM算法公式推导 (三硬币模型)
现在我们再来看看EM算法会怎么处理这个事情
我们先把EM算法的公式列出来
E-step
E
=
∫
z
l
o
g
P
(
y
,
z
∣
θ
)
P
(
z
∣
y
,
θ
t
)
d
z
E = \int_z logP(y,z|\theta)\ P(z|y,\theta^t)dz
E=∫zlogP(y,z∣θ) P(z∣y,θt)dz
M-step
θ
t
+
1
=
a
r
g
m
a
x
θ
E
\theta^{t+1} = \underset{\theta}{argmax}\ E
θt+1=θargmax E
对于离散的情况,计算E的时候,我们通常会转化成这样的式子(至于怎么转换的,白板视频中有推导,然后我贴的原博客中也有推导)
∑
j
=
1
N
[
∑
z
j
P
(
z
j
∣
y
j
,
θ
t
)
l
o
g
[
P
(
y
j
,
z
j
∣
θ
)
]
]
\sum_{j=1}^{N}\left[\sum_{z_{j}} P\left(z_{j} \mid y_{j}, \theta^{t}\right)\ log[P\left(y_{j}, z_{j} \mid \theta\right)]\right]
j=1∑N⎣⎡zj∑P(zj∣yj,θt) log[P(yj,zj∣θ)]⎦⎤
然后利用贝叶斯定理和全概率公式得到
P
(
z
j
∣
y
j
,
θ
t
)
=
P
(
y
j
∣
z
j
,
θ
t
)
∗
P
(
z
j
∣
θ
t
)
∑
k
=
1
2
P
(
y
j
∣
z
k
,
θ
t
)
∗
P
(
z
k
∣
θ
t
)
\begin{aligned} P(z_j \mid y_j,\theta^t) = \frac{P(y_j|z_j,\theta^t)*P(z_j|\theta^t)}{\sum_{k=1}^{2}P(y_j|z_k,\theta^t)*P(z_k|\theta^t)} \end{aligned}
P(zj∣yj,θt)=∑k=12P(yj∣zk,θt)∗P(zk∣θt)P(yj∣zj,θt)∗P(zj∣θt)
P ( y j , z j ∣ θ ) = P ( y j ∣ z j , θ ) ∗ P ( z j ∣ θ ) P(y_j,z_j|\theta) = P(y_j|z_j,\theta)*P(z_j|\theta) P(yj,zj∣θ)=P(yj∣zj,θ)∗P(zj∣θ)
定义
P
(
z
j
=
1
∣
y
j
,
θ
t
)
=
p
t
y
j
∗
(
1
−
p
t
1
−
y
j
)
∗
π
t
p
t
y
j
∗
(
1
−
p
t
1
−
y
j
)
∗
π
t
+
q
t
y
j
∗
(
1
−
q
t
1
−
y
j
)
∗
(
1
−
π
t
)
=
μ
j
t
P
(
z
j
=
0
∣
y
j
,
θ
t
)
=
q
t
y
j
∗
(
1
−
q
t
1
−
y
j
)
∗
(
1
−
π
t
)
p
t
y
j
∗
(
1
−
p
t
1
−
y
j
)
∗
π
t
+
q
t
y
j
∗
(
1
−
q
t
1
−
y
j
)
∗
(
1
−
π
t
)
=
1
−
μ
j
t
\begin{aligned} P(z_j=1|y_j,\theta^t) &= \frac{p^{t^{y_j}}*(1-p^{t^{1-y_j}})*\pi^t}{p^{t^{y_j}}*(1-p^{t^{1-y_j}})*\pi^t+q^{t^{y_j}}*(1-q^{t^{1-y_j}})*(1-\pi^t)} = \mu_j^{t} \\ \\ P(z_j=0|y_j,\theta^t) &= \frac{q^{t^{y_j}}*(1-q^{t^{1-y_j}})*(1-\pi^t)}{p^{t^{y_j}}*(1-p^{t^{1-y_j}})*\pi^t+q^{t^{y_j}}*(1-q^{t^{1-y_j}})*(1-\pi^t)} = 1-\mu_j^{t} \end{aligned}
P(zj=1∣yj,θt)P(zj=0∣yj,θt)=ptyj∗(1−pt1−yj)∗πt+qtyj∗(1−qt1−yj)∗(1−πt)ptyj∗(1−pt1−yj)∗πt=μjt=ptyj∗(1−pt1−yj)∗πt+qtyj∗(1−qt1−yj)∗(1−πt)qtyj∗(1−qt1−yj)∗(1−πt)=1−μjt
P ( y j , z j = 1 ∣ θ ) = p y j ∗ ( 1 − p 1 − y j ) ∗ π P ( y j , z j = 0 ∣ θ ) = q y j ∗ ( 1 − q 1 − y j ) ∗ ( 1 − π ) \begin{aligned} P(y_j,z_j=1|\theta) = p^{y_j}*(1-p^{1-y_j})*\pi \\ P(y_j,z_j=0|\theta) = q^{y_j}*(1-q^{1-y_j})*(1-\pi) \end{aligned} P(yj,zj=1∣θ)=pyj∗(1−p1−yj)∗πP(yj,zj=0∣θ)=qyj∗(1−q1−yj)∗(1−π)
得到
Q
(
θ
,
θ
t
)
=
∑
j
=
1
N
{
μ
j
t
⋅
ln
[
π
⋅
p
y
j
(
1
−
p
)
1
−
y
j
]
+
(
1
−
μ
j
t
)
⋅
ln
[
(
1
−
π
)
q
y
j
(
1
−
q
)
1
−
y
j
]
}
Q\left(\theta, \theta^{t}\right)=\sum_{j=1}^{N}\left\{\mu_{j}^{t} \cdot \ln \left[\pi \cdot p^{y_{j}}(1-p)^{1-y_{j}}\right]+\left(1-\mu_{j}^{t}\right) \cdot \ln \left[(1-\pi) q^{y_{j}}(1-q)^{1-y_{j}}\right]\right\}
Q(θ,θt)=j=1∑N{μjt⋅ln[π⋅pyj(1−p)1−yj]+(1−μjt)⋅ln[(1−π)qyj(1−q)1−yj]}
对
π
\pi
π求偏导
∂
Q
∂
π
=
∑
j
=
1
N
{
μ
j
t
⋅
1
π
+
(
1
−
μ
j
t
)
⋅
−
q
y
j
(
1
−
q
)
1
−
y
j
(
1
−
π
)
q
y
j
(
1
−
q
)
1
−
y
j
}
=
∑
j
=
1
N
μ
j
t
−
π
π
⋅
(
1
−
π
)
=
0
\begin{aligned} \frac{\partial Q}{\partial \pi} &= \sum_{j=1}^{N}\{\mu_{j}^{t} \cdot\frac{1}{ \pi}+\left(1-\mu_{j}^{t}\right) \cdot \frac{-q^{y_{j}}(1-q)^{1-y_{j}}}{(1-\pi) q^{y_{j}}(1-q)^{1-y_{j}}} \} \\ &=\sum_{j=1}^{N}\frac{\mu_{j}^{t}-\pi}{\pi \cdot (1-\pi)} = 0 \end{aligned}
∂π∂Q=j=1∑N{μjt⋅π1+(1−μjt)⋅(1−π)qyj(1−q)1−yj−qyj(1−q)1−yj}=j=1∑Nπ⋅(1−π)μjt−π=0
得到
π
t
+
1
=
1
N
∑
j
=
1
N
μ
j
t
\pi^{t+1} = \frac{1}{N}\sum_{j=1}^{N}\mu_{j}^{t}
πt+1=N1j=1∑Nμjt
对
p
p
p求偏导
∂
Q
∂
p
=
∑
j
=
1
N
μ
j
t
⋅
∂
∂
p
[
ln
π
+
y
ln
p
+
(
1
−
y
)
ln
(
1
−
p
)
]
=
∑
j
=
1
N
u
j
t
⋅
(
y
p
+
1
−
y
1
−
p
)
=
∑
j
=
1
N
u
j
t
⋅
(
y
j
−
p
p
⋅
(
1
−
p
)
)
=
0
⇒
∑
j
=
1
N
u
j
t
⋅
y
j
−
∑
j
=
1
N
u
j
t
⋅
p
=
0
\begin{aligned} \frac{\partial Q}{\partial p} &= \sum_{j=1}^{N} \mu_j^{t} \cdot \frac{\partial}{\partial p}[\ln \pi+y \ln p+(1-y) \ln (1-p)]\\ &= \sum_{j=1}^{N}u_j^t\cdot(\frac{y}{p}+\frac{1-y}{1-p}) \\ &= \sum_{j=1}^{N}u_j^t\cdot(\frac{y_j-p}{p\cdot(1-p)}) = 0 \\ &\Rightarrow \sum_{j=1}^{N}u_j^t\cdot y_j-\sum_{j=1}^{N}u_j^t\cdot p = 0 \end{aligned}
∂p∂Q=j=1∑Nμjt⋅∂p∂[lnπ+ylnp+(1−y)ln(1−p)]=j=1∑Nujt⋅(py+1−p1−y)=j=1∑Nujt⋅(p⋅(1−p)yj−p)=0⇒j=1∑Nujt⋅yj−j=1∑Nujt⋅p=0
得到
p
t
+
1
=
∑
j
=
1
N
u
j
t
⋅
y
j
∑
j
=
1
N
u
j
t
p^{t+1} = \frac{\sum_{j=1}^{N}u_j^t\cdot y_j}{\sum_{j=1}^{N}u_j^t}
pt+1=∑j=1Nujt∑j=1Nujt⋅yj
对q求偏导(略),q和p是类似的,只要换掉权重
u
j
t
u_j^t
ujt为
1
−
u
j
t
1-u_j^t
1−ujt即可
得到
q
t
+
1
=
∑
j
=
1
N
(
1
−
u
j
t
)
⋅
y
j
∑
j
=
1
N
(
1
−
u
j
t
)
q^{t+1} = \frac{\sum_{j=1}^{N}(1-u_j^t)\cdot y_j}{\sum_{j=1}^{N}(1-u_j^t)}
qt+1=∑j=1N(1−ujt)∑j=1N(1−ujt)⋅yj
至此,我们已经获得了
π
,
p
,
q
\pi,p,q
π,p,q的递推式
π
t
+
1
=
1
N
∑
j
=
1
N
μ
j
t
p
t
+
1
=
∑
j
=
1
N
u
j
t
⋅
y
j
∑
j
=
1
N
u
j
t
q
t
+
1
=
∑
j
=
1
N
(
1
−
u
j
t
)
⋅
y
j
∑
j
=
1
N
(
1
−
u
j
t
)
\begin{aligned} &\pi^{t+1} = \frac{1}{N}\sum_{j=1}^{N}\mu_{j}^{t} \\ &p^{t+1} = \frac{\sum_{j=1}^{N}u_j^t\cdot y_j}{\sum_{j=1}^{N}u_j^t} \\ &q^{t+1} = \frac{\sum_{j=1}^{N}(1-u_j^t)\cdot y_j}{\sum_{j=1}^{N}(1-u_j^t)} \end{aligned}
πt+1=N1j=1∑Nμjtpt+1=∑j=1Nujt∑j=1Nujt⋅yjqt+1=∑j=1N(1−ujt)∑j=1N(1−ujt)⋅yj
得到递推式后,我们就可以反复迭代这几个式子,直到
π
t
+
1
,
p
t
+
1
,
q
t
+
1
≈
π
t
,
p
t
,
q
t
\pi^{t+1},p^{t+1},q^{t+1} \approx \pi^{t},p^{t},q^{t}
πt+1,pt+1,qt+1≈πt,pt,qt
EM算法实现
下面我们用代码实现一下三硬币模型
首先我们先定义数据集
假设我们投了20次A硬币,得到B、C的投掷结果如下
observations = [1,0,0,1,1,1,0,1,1,1,0,0,0,0,1,1,0,1,0,0]
初始A,B,C正面朝上的概率分别为0.4,0.6,0.5
pi,p,q = 0.4,0.6,0.5
遍历观测序列
def em(pi,p,q,max_iter = 100,toler = 0.001):
for epoch in range(max_iter):
p_up,p_down,q_up,q_down = .0,.0,.0,.0
for observation in observations:
if observation==1:
ut = p*pi/(p*pi+q*(1-pi))
else:
ut = (1-p)*pi/((1-p)*pi+(1-q)*(1-pi))
p_up += ut*observation
p_down += ut
q_up += (1-ut)*observation
q_down += 1-ut
pi_next = p_down/len(observations)
p_next = p_up/p_down
q_next = q_up/q_down
# if np.abs(pi-pi_next)<toler and np.abs(p-p_next)<toler and np.abs(q-q_next)<toler:
if pi==pi_next and p==p_next and q==q_next:
pi = pi_next
p = p_next
q = q_next
break
else:
pi = pi_next
p = p_next
q = q_next
print("epoch %s:%s %s %s"%(epoch+1,pi,p,q))
return pi,p,q
执行em函数
em(pi,p,q)
输出
epoch 1:0.41711229946524064 0.435897435897436 0.5458715596330275
epoch 2:0.4171122994652407 0.43589743589743596 0.5458715596330275
epoch 3:0.4171122994652408 0.43589743589743596 0.5458715596330275
(0.4171122994652408, 0.43589743589743596, 0.5458715596330275)