基于 EM 算法的线性拟合-Python实现

基于 EM 算法的线性拟合-Python实现

理论

具体参考论文:
王礼想.基于EM算法的线性拟合问题研究[J].廊坊师范学院学报(自然科学版),2013,13(04):21-23.

假定已有一组相对分散的数据,要对数据进行线性拟合,需要解决两个问题:
(1)确定直线条数、估计各直线参数——斜率与截距;
(2)为每个点分配直线。
直觉上,对于这两个问题若事先知道其中任何一个,另一问题可迎刃而解。即若已知每点的归属,把点代入所属直线,则直线的斜率和截距就解决了;若已知直线方程,把各点代入,也容易判断各点所属直线。由这一启发,下面给出 EM 算法流程:

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

Python 代码

生成点

#随便生成两类点
t = np.linspace(-0.5, 1, 50)
m1 = 120*t+100+np.random.normal(loc=0.0, scale=30.0,size = 50);m2 = 20*t+30+np.random.normal(loc=0.0, scale=30,size = 50)
#两类点混一起为xi,yi
xi = np.append(t,t);yi = np.append(m1,m2)

EM算法

a1,b1 = 10,20; a2,b2 = 80,30 #初始值
e = 1
r1i = a1*xi+b1-yi; r2i = a2*xi+b2-yi

while abs(e)>0.1:#循环至稳定
    #E步
    sigma2 = sum(r1i**2)+sum(r2i**2)
    s = np.exp(-r1i**2/sigma2) + np.exp(-r2i**2/sigma2)
    w1i = np.exp(-r1i**2/sigma2)/s; w2i = np.exp(-r2i**2/sigma2)/s
    x1 = xi[w1i>w2i];x2 = xi[w1i<=w2i]
    y1 = yi[w1i>w2i];y2 = yi[w1i<=w2i]
    w1 = w1i[w1i>w2i];w2 = w2i[w1i<=w2i]
    #M步
    a1,b1 = np.linalg.solve(np.array([[x1**2@w1,x1@w1],[x1@w1,sum(w1)]]), np.array([[x1*y1@w1],[y1@w1]]))
    a2,b2 = np.linalg.solve(np.array([[x2**2@w2,x2@w2],[x2@w2,sum(w2)]]), np.array([x2*y2@w2,y2@w2]))
    r1i = a1*xi+b1-yi; r2i = a2*xi+b2-yi
    #更新e
    e = sigma2-sum(r1i**2)-sum(r2i**2)

绘图看看

#原始的两类点
plt.scatter(t,m2,s=3)
plt.scatter(t,m1,s=3)
#EM算法拟合出来的两条直线
t1 = a1*t+b1;t2 = a2*t+b2
plt.plot(t,t1)
plt.plot(t,t2)
plt.show()

在这里插入图片描述

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值