《统计学习方法》第9章_EM算法及推广

代码实现如下:

# encoding:utf-8
"""E step: 参考统计学习方法P式(9.5)"""
import math

class EM:
    """初始化模型参数"""
    def __init__(self, prob):
        """pro_A,pro_B,pro_C分别表示硬币A,B,C正面出现的概率"""
        self.pro_A, self.pro_B, self.pro_C = prob

    """E步,求期望(实现式(9.5))"""
    def pmf(self, i, data):
        """式(9.5)分子的实现"""
        pro_1 = self.pro_A * math.pow(self.pro_B, data[i]) * math.pow((1 - self.pro_B), 1 - data[i])
        """式(9.5)分母的实现"""
        pro_2 = (1 - self.pro_A) * math.pow(self.pro_C, data[i]) * math.pow((1- self.pro_C), (1 - data[i]))
        """返回式(9.5)的计算结果"""
        return pro_1 / (pro_1 + pro_2)

    """M步,求极大化,迭代求出模型参数的新估计值"""
    def fit(self, data):
        count = len(data)
        print('init prob:{}, {}, {}'.format(self.pro_A, self.pro_B, self.pro_C))
        for d in range(count):
            _ = yield
            _pmf = [self.pmf(k, data) for k in range(count)]
            """计算π(i+1)的新估计值"""
            pro_A = 1 / count * sum(_pmf)
            """计算p(i+1)的新估计值"""
            pro_B = sum([self.pmf(k, data) * data[k] for k in range(count)]) / sum(
                [self.pmf(k, data) for k in range(count)])
            """计算q(i+1)的新估计值"""
            pro_C = sum([(1 - self.pmf(k, data)) * data[k] for k in range(count)]) / sum(
                [(1 - self.pmf(k, data)) for k in range(count)])
            print('{}/{} pro_a:{:.4f}, pro_b:{:.4f}, pro_c:{:.4f}'.format(d + 1, count, pro_A, pro_B, pro_C))
            self.pro_A = pro_A
            self.pro_B = pro_B
            self.pro_C = pro_C


data = [1, 1, 0, 1, 0, 0, 1, 0, 1, 1]

em = EM(prob=[0.5, 0.5, 0.5])
f = em.fit(data)
next(f)

"""第一次迭代"""
f.send(1)

"""第二次迭代"""
f.send(2)

"""更换初始值"""
em = EM(prob=[0.4, 0.6, 0.7])
f2 = em.fit(data)
next(f2)

"""第一次迭代"""
f2.send(1)

"""第二次迭代"""
f2.send(2)

运行结果:

 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值