EM算法(Expectation-Maximum)与双硬币问题的求解(附代码实现)
EM算法是一种迭代优化策略,由于它的计算方法中每一次迭代都分两步,其中一个为期望步(E步),另一个为极大步(M步),所以算法被称为EM算法(Expectation-Maximization Algorithm),最初是为了解决数据缺失情况下的参数估计问题。双硬币问题
可以作为一个非常好的实例:
假如有五枚硬币,没枚硬币抛十次,记录硬币向上和向下的次数。硬币总共有A,B两种,在事先我们已经知道了这五枚硬币的种类分别是什么,那么我们可以很轻松的算出来两种硬币朝上的概率(由于硬币本身原因,朝上的概率不为50%)。但是假如我们一开始并不知道这五枚硬币的种类,那么问题就变得棘手了,EM算法则可以帮助我们解决这个问题。
EM算法求解双硬币问题(c++实现):
//假设有A,B两种硬币,现在得到的数据中并没有硬币种类的信息和硬币正面朝上的信息
#include <iostream>
#include "math.h"
using namespace std;
//用于阶乘的函数
constexpr int f(int i) {
return i < 2 ? 1 : i * f(i-1);
}
int main(){
//得到的硬币数据展示
int samples[5][10]={{1,0,0,0,1,1,0,0,1,1},
{1,1,1,1,1,1,1,1,0,0},
{1,1,1,1,1,1,1,1,0,0},
{1,1,1,1,0,0,0,0,0,0},
{1,1,1,1,1,1,1,0,0,0},
};
//初始化A,B两种硬币的朝上概率
double prob_A = 0.6; double prob_B = 0.5;
double front[5];//front[i]表示对于第i组,硬币朝上的概率
//根据得到的数据,算出front
for(int j = 0; j< 5; j++){
front[j] = 0;
for(int k = 0; k< 10; k++){
front[j] += samples[j][k];
}
front[j] = front[j]/10;
}
//假设循环1000次后能够得到较为精确的解
for(int i = 0; i < 1000 ;i++){
//计算第i组硬币为A和为B的概率
double contribution_A = 0; double contribution_B = 0;
double weight_A[5]; double weight_B[5];
double num_AH[5]; double num_BH[5];
double num_AT[5]; double num_BT[5];
for(int n = 0; n < 5; n++){
contribution_A = f(10)/(f(front[n]*10)*f(10-front[n]*10))*pow(prob_A,(front[n]*10))*pow((1-prob_A),(10-front[n]*10));
contribution_B = f(10)/(f(front[n]*10)*f(10-front[n]*10))*pow(prob_B,(front[n]*10))*pow((1-prob_B),(10-front[n]*10));
weight_A[n] = contribution_A/(contribution_A+contribution_B);
weight_B[n] = contribution_B/(contribution_A+contribution_B);
num_AH[n] = weight_A[n]*front[n]*10;
num_BH[n] = weight_B[n]*front[n]*10;
num_AT[n] = weight_A[n]*(1-front[n])*10;
num_BT[n] = weight_B[n]*(1-front[n])*10;
}
double sum_AH = 0; double sum_BH = 0;
double sum_AT = 0; double sum_BT = 0;
for(int n = 0; n < 5; n++){
sum_AH += num_AH[n];
sum_BH += num_BH[n];
sum_AT += num_AT[n];
sum_BT += num_BT[n];
}
prob_A = sum_AH/(sum_AH + sum_AT);
prob_B = sum_BH/(sum_BH + sum_BT);
}
cout << "prob_A:" << prob_A << " ";
cout << "prob_B:" << prob_B << " ";
cout << "\n";
}
运行结果:
Reference:
①:http://www.zzvips.com/article/108550.html
②:https://zhuanlan.zhihu.com/p/40991784
③:https://blog.csdn.net/zhihua_oba/article/details/73776553