题目要求:使用 EM 算法估计三硬币模型的参数 π、p、q(观测序列与参数初值见下方代码)。
给出解决代码如下:
import numpy as np
def miu_calc(pii, pi, qi, yj):
    """Compute the E-step responsibilities of a two-component Bernoulli mixture.

    For each observation ``y`` the result is

        pii * pi**y * (1 - pi)**(1 - y)
        -------------------------------------------------------------
        pii * pi**y * (1 - pi)**(1 - y) + (1 - pii) * qi**y * (1 - qi)**(1 - y)

    i.e. the posterior weight of the first component (mixing weight ``pii``,
    success probability ``pi``) against the second component (mixing weight
    ``1 - pii``, success probability ``qi``).

    Args:
        pii: Mixing weight of the first component.
        pi: Success probability of the first component.
        qi: Success probability of the second component.
        yj: Observation(s); a scalar, a sequence, or a NumPy array.
            The formula is the Bernoulli likelihood, so values are expected
            to be 0 or 1.

    Returns:
        The responsibilities, with the same shape as ``yj``.
    """
    # Coerce so that plain Python lists/tuples work as exponents too;
    # ``float ** list`` would otherwise raise TypeError.
    yj = np.asarray(yj)
    up_b = pii * pi ** yj * (1 - pi) ** (1 - yj)
    up_c = (1 - pii) * qi ** yj * (1 - qi) ** (1 - yj)
    return up_b / (up_b + up_c)
# Observed binary outcomes of the ten trials.
yj = np.array([1, 1, 0, 1, 0, 0, 1, 0, 1, 1])

# Initial parameter guesses: mixing weight and the two success probabilities.
pi0 = 0.46
p0 = 0.55
q0 = 0.67

# First EM step, starting from the initial guesses.
miu = miu_calc(pi0, p0, q0, yj)
pi_n = miu.mean()
p_n = (miu * yj).sum() / miu.sum()
q_n = ((1 - miu) * yj).sum() / (1 - miu).sum()

# 100 further EM steps, each starting from the previous estimates.
for ii in range(100):
    # E-step: responsibilities of the first component under current estimates.
    miu = miu_calc(pi_n, p_n, q_n, yj)
    # M-step: the new mixing weight is the mean responsibility; each success
    # probability is the responsibility-weighted mean of the observations.
    pi_n = miu.mean()
    p_n = (miu * yj).sum() / miu.sum()
    q_n = ((1 - miu) * yj).sum() / (1 - miu).sum()

# Report the final estimates once the iterations are done.
print(pi_n, p_n, q_n)
算法解析:
def miu_calc(pii, pi, qi, yj):
    """Compute the E-step responsibilities of a two-component Bernoulli mixture.

    For each observation ``y`` the result is

        pii * pi**y * (1 - pi)**(1 - y)
        -------------------------------------------------------------
        pii * pi**y * (1 - pi)**(1 - y) + (1 - pii) * qi**y * (1 - qi)**(1 - y)

    i.e. the posterior weight of the first component (mixing weight ``pii``,
    success probability ``pi``) against the second component (mixing weight
    ``1 - pii``, success probability ``qi``).

    Args:
        pii: Mixing weight of the first component.
        pi: Success probability of the first component.
        qi: Success probability of the second component.
        yj: Observation(s); a scalar, a sequence, or a NumPy array.
            The formula is the Bernoulli likelihood, so values are expected
            to be 0 or 1.

    Returns:
        The responsibilities, with the same shape as ``yj``.
    """
    # Coerce so that plain Python lists/tuples work as exponents too;
    # ``float ** list`` would otherwise raise TypeError.
    yj = np.asarray(yj)
    up_b = pii * pi ** yj * (1 - pi) ** (1 - yj)
    up_c = (1 - pii) * qi ** yj * (1 - qi) ** (1 - yj)
    return up_b / (up_b + up_c)
for ii in range(100):
    # E-step: responsibilities of the first component under current estimates.
    miu = miu_calc(pi_n, p_n, q_n, yj)
    # M-step: the new mixing weight is the mean responsibility; each success
    # probability is the responsibility-weighted mean of the observations.
    pi_n = miu.mean()
    p_n = (miu * yj).sum() / miu.sum()
    q_n = ((1 - miu) * yj).sum() / (1 - miu).sum()
最终得到结果:
代码来源及参考文章:统计学习方法 EM算法三硬币问题python代码 - 知乎