为了即将到来的考试,还是整理一下EM算法吧。
EM算法是一个统计中巧妙得处理有缺失数据的统计推断问题的优化方法,我们预想我们的完全数据分为两部分,一部分为直接观测的数据,一部分为无法直接观测的数据。即
,这里
分别是观测数据和不可观测数据。对于给定的观测数据
,我们的统计推断问题通常想极大似然
,但是它通常难以处理,具体可从以下可以看到,由于它要把
积分掉。EM算法采用比较容易处理的
的密度函数,从而巧妙得避开了直接处理
。
以上的两张大致介绍了EM的想法以及算法过程,下面用一张图片来直观得展示EM算法在迭代过程中的表现:
可以看出,每次参数迭代更新的过程中,参数
都保证了似然函数上升(非降)。虽然这样并不能保证一定可以达到全局极值,但是依然是一种十分快速且有效的优化算法。
说完了理论,下面我们来看一下具体的例子帮助理解。
EM算法一共有三个步骤:
- 写出完全数据的似然函数
- 求出关于当前参数和观测数据的条件期望
- 对于第二步中的条件期望进行求导,得出关于参数的迭代公式
下面看一下在这个例子中是怎么做的。
这样我们只需呀给定初值就可以通过最后一共公式迭代计算参数的估计值。
我们考虑如下的观测数据:
给出算法的具体实现。
import numpy as np
import matplotlib.pyplot as plt
X = np.array([34,18,20,125]) # obs data
itr = 40 # iteration numbers
theta_es = np.zeros(itr)
t = 1 # initial value
def update(x):
y = (X[0]+X[3]*(x/(2+x)))/(X[0]+X[1]+X[2]+X[3]*(x/(2+x)))
return y
for i in range(itr):
t = update(t)
theta_es[i] = t
plt.plot(theta_es)
plt.title('theta estimation')
plt.xlabel('Iteration numbers')
plt.ylabel('Estimation value')
可以看出最后参数
稳定在了
。
关于EM算法还有很多东西,就以后再写啦!