1.EM算法是什么
EM算法可以用于有监督学习,也可以用于无监督学习。这个算法是根据观测结果求得对含有隐变量的模型的参数的估计。包含E步骤和M步,E步是求期望,M步是求极大似然估计,极大参数估计是对模型参数估计的一种方法。一个典型的应用EM算法进行参数估计的例子就是敏感问题的调查,我们想要得到人群中吸烟人数的比例,可以设置这样一个问卷
- 问题1:你的手机尾号是偶数吗?若是,回答问题2,不是,则回答问题3
- 问题2:你吸烟吗
- 问题3:你喜欢养猫吗
通过调查,我们获得结果是:
是 | 否 | |
---|---|---|
问题2 | N1 | N2 |
问题3 | N3 | N4 |
其中,N1,N2,N3,N4表示的是人数,在这里隐变量就是手机尾号。根据观测的结果进行人群中吸烟人群比例的估计。这是一个比较通俗的例子来理解EM算法。接下来对李航老师的《统计学习方法》书中的抛硬币的例子进行python求解。
2.EM算法的Python实现
输入:原始数据,模型的初始化的参数
输出:模型的参数
训练:就是根据训练数据进行模型参数调整的过程
import numpy as np
class EM():
def __init__(self,theta1,theta2,theta3,epochs):
self.theta1 = theta1
self.theta2 = theta2
self.theta3 = theta3
self.epochs = 20
# 遍历每一个样本,求对应标签以及概率
def prob_maix_function(self,data):
numerator = np.zeros(len(data))
denominator = np.zeros(len(data))
for i in range(len(data)):
numerator[i] = self.theta1*pow(self.theta2,data[i])*pow((1-self.theta2),1-data[i])
denominator[i] = (1-self.theta1) * pow(self.theta3,data[i])*pow((1-self.theta3),1-data[i])
mu = numerator/(numerator+denominator)
return mu
def fit(self,data):
self.x_train = data
n = len(self.x_train)
thre = 2
for epoch in range(self.epochs): ## 遍历数据集更新参数的个数,相当于full batch都用来进行模型参数的更新
mu_p = self.prob_maix_function(self.x_train)
theta_1_ = 1/n * sum(mu_p)
theta_2_ = sum(mu_p * self.x_train)/ sum(mu_p)
theta_3_ = sum((1-mu_p) * self.x_train)/ sum(mu_p)
delta = abs(self.theta1 - theta_1_) + abs(self.theta2-theta_2_) + abs(self.theta3 -theta_3_)
if delta >= 1e-4:
self.theta1 = theta_1_
self.theta2 = theta_2_
self.theta3 = theta_3_
else:
break
print('Epoch: %.f' % (epoch))
print('Theta1: %.4f' % (self.theta1))
print('Theta2: %.4f' % (self.theta2))
print('Theta3: %.4f' % (self.theta3))
print('-----------------------------------')
if __name__ =='__main__':
data = np.array([1, 1, 0, 1, 0, 0, 1, 0, 1, 1])
prob1 = 0.46
prob2 = 0.55
prob3 = 0.67
epochs = 20
clf = EM(prob1,prob2,prob3,epochs)
clf.fit(data)
我在这里设置了epoch=20,这个是遍历数据集的次数,没遍历一次数据集,进行一次模型参数的更新,训练是否停止的限制是通过语句 if delta >= 1e-4:实现的,如果误差很小,就停止训练,输出模型参数。
3.总结
理解隐变量的作用,以及隐藏变量下的条件概率的计算。最大化似然函数转换成最大化似然函数的下界问题,这个转换是通过Jesen不等式完成的。