机器学习-EM算法

EM算法(最大期望算法,Expectation-Maximization algorithm,简称EM)是一种迭代算法,用于含有缺失数据(隐变量)的概率模型的参数估计。算法由E步(求期望)和M步(求极大)交替迭代组成,其结果对初始参数的选取依赖较大。

Python代码实现

from numpy import *
import numpy as np
import matplotlib.pyplot as plt
import random


def create_sample_data(m, n):
    """Generate a random 0/1 observation matrix.

    Each row represents the outcome of one experiment; every entry is
    drawn independently with ``random.randint(0, 1)``.

    Args:
        m: Number of rows (experiments).
        n: Number of columns (observations per experiment).

    Returns:
        An ``m`` x ``n`` ``numpy.matrix`` of floats whose entries are 0.0
        or 1.0.
    """
    # Draw in row-major order so the sequence of random calls matches a
    # nested "for row / for column" fill.
    draws = [random.randint(0, 1) for _ in range(m * n)]
    # np.asmatrix is used instead of the `mat` alias, which NumPy 2.0
    # removed from the main namespace.
    return np.asmatrix(np.array(draws, dtype=float).reshape(m, n))


# EM algorithm
def em(arr_y, theta, tol, iterator_num):
    """Fit a two-component Bernoulli mixture with the EM algorithm.

    The model has three parameters: ``pi`` is the weight of the first
    component, ``p`` is the probability of observing 1 under the first
    component and ``q`` the probability of observing 1 under the second.
    Every column of the data is fitted independently, so a single-column
    input yields one (pi, p, q) triple.

    Args:
        arr_y: 0/1 observations, one observation per row. A
            ``numpy.matrix``, 2-D array or nested list is accepted; a 1-D
            sequence is treated as a single column.
        theta: Mutable sequence ``[pi, p, q]`` holding the initial guess.
            It is updated in place after every iteration with the newest
            estimates (each stored as an array with one entry per column).
        tol: Convergence threshold; iteration stops once all three
            parameters change by at most ``tol`` in one iteration.
        iterator_num: Maximum number of EM iterations.

    Returns:
        Tuple ``(pi, p, q)`` with the most recent estimates. After at
        least one iteration each item is an array with one entry per data
        column; with ``iterator_num == 0`` the initial guess is returned.

    Raises:
        ValueError: If the data contains no observations.
    """
    y = np.asarray(arr_y, dtype=float)
    if y.ndim == 1:
        y = y.reshape(-1, 1)
    m = y.shape[0]
    if m == 0:
        raise ValueError("arr_y must contain at least one observation")

    pi, p, q = (np.asarray(value, dtype=float)
                for value in (theta[0], theta[1], theta[2]))

    for _ in range(iterator_num):
        prev_pi, prev_p, prev_q = pi, p, q

        # E-step: responsibility of the first component for each observation.
        weight_first = prev_pi * prev_p ** y * (1 - prev_p) ** (1 - y)
        weight_second = (1 - prev_pi) * prev_q ** y * (1 - prev_q) ** (1 - y)
        miu = weight_first / (weight_first + weight_second)

        # M-step: re-estimate the parameters from the responsibilities.
        miu_total = miu.sum(axis=0)
        pi = miu_total / m
        p = (miu * y).sum(axis=0) / miu_total
        q = ((1 - miu) * y).sum(axis=0) / (1 - miu).sum(axis=0)

        # Keep the caller's list in sync with the latest estimates.
        theta[0], theta[1], theta[2] = pi, p, q

        print("------------------------------------------")
        print(theta)

        # np.all lets the check work when there is more than one column.
        converged = (np.all(np.abs(pi - prev_pi) <= tol)
                     and np.all(np.abs(p - prev_p) <= tol)
                     and np.all(np.abs(q - prev_q) <= tol))
        if converged:
            print("break")
            break

    # Return the updated estimates, not the values from before the last
    # M-step.
    return pi, p, q


def main():
    """Run EM on a fixed ten-observation sample and print the estimates.

    Prints the observation matrix, the per-iteration output of ``em`` and
    finally the three values returned by ``em``.
    """
    # Ten 0/1 observations, one per row. To experiment with random data
    # instead, build the matrix with create_sample_data(100, 1).
    observations = [1, 1, 0, 1, 0, 0, 1, 0, 1, 1]
    # np.asmatrix replaces the `mat` alias that NumPy 2.0 removed; the
    # transpose turns the single row into a 10 x 1 column matrix.
    mat_y = np.asmatrix(observations, dtype=float).T

    # The three initial parameters can be changed freely to compare the
    # results; EM is sensitive to its starting point.
    theta = [0.4, 0.6, 0.7]
    print(mat_y)
    PI, P, Q = em(mat_y, theta, 0.001, 100)
    print(PI, P, Q)


# Only run the demo when executed as a script, not when imported.
if __name__ == "__main__":
    main()

参考

《统计学习方法》—— 李航

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值