EM算法的简易实现
EM算法是博主转到软件工程专业后尝试实现的第一个算法,用的是C语言。虽然算法很简单,但还是费了很大功夫才写出来。代码写完已经很久了,现在才想起来记录一下。
1.论文引入
2.数据集
3.代码实现
#include<stdio.h>
#include<math.h>
/* E-step weight for coin A: with likelihood p^num3 * (1-p)^num4 evaluated at
   p = num1 (coin A) and p = num2 (coin B), returns likeA / (likeA + likeB).
   main passes the group's head count as num3 and tail count as num4. */
double Estep1(double num1, double num2, double num3, double num4);
/* E-step weight for coin B: one minus the coin-A weight computed from the
   same four arguments. */
double Estep2(double num1, double num2, double num3, double num4);
/* Returns the product numb1 * numb2; main uses it to weight a head or tail
   count by an E-step weight. */
double Mstep1(double numb1, double numb2);
/*
 * Two-coin EM demo.
 *
 * Five groups of ten tosses (1 = heads, 0 = tails) are each assumed to come
 * from one of two coins, A or B, whose head probabilities are unknown.
 * Starting from the guesses p1 = 0.6 (coin A) and p2 = 0.5 (coin B), the
 * program runs ten E-step/M-step rounds, prints the estimates after every
 * round, and finally labels each group with the coin whose estimated head
 * probability is closer to the group's observed head fraction.
 *
 * Returns 0.
 */
int main(void)
{
    enum { GROUPS = 5, TOSSES = 10, ROUNDS = 10 };

    int a[GROUPS] = { 0 };      /* heads counted in each group */
    int b[GROUPS] = { 0 };      /* tails counted in each group */
    double ap[GROUPS] = { 0 };  /* E-step weight of coin A for each group */
    double bp[GROUPS] = { 0 };  /* E-step weight of coin B for each group */
    double p1 = 0.6;            /* initial guess for coin A's head probability */
    double p2 = 0.5;            /* initial guess for coin B's head probability */
    int i, j, k, n;
    int arr[GROUPS][TOSSES] =
    {
        {1,0,0,0,1,1,0,1,0,1},
        {1,1,1,1,0,1,1,1,1,1},
        {1,0,1,1,1,1,1,0,1,1},
        {1,0,1,0,0,0,1,1,0,0},
        {0,1,1,1,0,1,1,1,0,1}
    };

    /* Count heads and tails in every group. */
    for (i = 0; i < GROUPS; i++) {
        for (j = 0; j < TOSSES; j++) {
            if (arr[i][j] == 1) {
                a[i]++;
            } else {
                b[i]++;
            }
        }
        printf("正面次数为%d,反面次数为%d\n", a[i], b[i]);
    }
    printf("设 p1=0.6, p2=0.5\n");

    for (n = 0; n < ROUNDS; n++) {
        /*
         * The expected-count accumulators must start from zero in every
         * round. Carrying them over from earlier rounds mixes stale
         * E-step weights into the new estimate and keeps p1/p2 from
         * converging to the EM fixed point.
         */
        double s1 = 0.0;  /* expected heads attributed to coin A */
        double s2 = 0.0;  /* expected tails attributed to coin A */
        double s3 = 0.0;  /* expected heads attributed to coin B */
        double s4 = 0.0;  /* expected tails attributed to coin B */

        /* E-step: weight of each coin for each group under current p1, p2. */
        for (k = 0; k < GROUPS; k++) {
            ap[k] = Estep1(p1, p2, a[k], b[k]);
            bp[k] = Estep2(p1, p2, a[k], b[k]);
            printf("%.2lf %.2lf \n", ap[k], bp[k]);
        }

        /* M-step: accumulate weighted head/tail counts for each coin. */
        for (k = 0; k < GROUPS; k++) {
            s1 += Mstep1(a[k], ap[k]);
            s2 += Mstep1(b[k], ap[k]);
            s3 += Mstep1(a[k], bp[k]);
            s4 += Mstep1(b[k], bp[k]);
        }

        /* New head probability = expected heads / expected tosses. */
        p1 = s1 / (s1 + s2);
        p2 = s3 / (s3 + s4);
        printf("A的概率为:%.2lf,B的概率为:%.2lf\n", p1, p2);
    }

    /* Label each group by whichever estimate is nearer its head fraction. */
    printf("最终结果为%.2lf和%.2lf\n", p1,p2);
    for (k = 0; k < GROUPS; k++) {
        double s = a[k] * 1.0 / TOSSES;

        printf("s=%.2lf\n", s);
        if (fabs(s - p1) < fabs(s - p2)) {
            printf("第%d组为硬币A\n", k+1);
        } else {
            printf("第%d组为硬币B\n",k+1);
        }
    }

    return 0;
}
double Estep1(double num1, double num2, double num3, double num4)
{/*p1=1 p2=2 n1=3 n1f=4 */
double a1, b1, ap1, bp1;
a1 = pow(num1, num3) * pow((1.0 - num1), num4);
b1 = pow(num2, num3) * pow((1.0 - num2), num4);
ap1 = a1 / (a1 + b1);//第一组是硬币A的概率
bp1 = 1.0 - ap1;//第一组是硬币B的概率
return ap1;
}
double Estep2(double num1, double num2, double num3, double num4)
{
double a1, b1, ap1, bp1;
a1 = pow(num1, num3) * pow((1.0 - num1), num4);
b1 = pow(num2, num3) * pow((1.0 - num2), num4);
ap1 = a1 / (a1 + b1);
bp1 = 1.0 - ap1;
return bp1;
}
/*
 * M-step helper: returns numb1 multiplied by numb2. The caller uses it to
 * scale a head or tail count by an E-step weight.
 */
double Mstep1(double numb1, double numb2)
{
    return numb1 * numb2;
}
结果
4.总结
中间还有很多需要完善的地方;最后的结果一直不对,也不清楚原因,不知道是不是精度问题。水平还很菜,需要继续努力。