EM算法是一个很经典的算法,有人成为上帝算法,可以可以在你不知道样本类别的情况下求出该样本的类别,前提你需要知道样本服从什么分布。是一种经典的无监督学习算法。
平常我们求解最优问题,通常采用最小二乘法,梯度下降法,高斯牛顿法,牛顿法,拟牛顿法,列-马算法等等。但是在使用这些方法之前通常会使用极大似然估计或者拉格朗日乘子法作为前序,同样EM算法也是极大似然估计的后续。极大似然估计是把累乘问题通过对数似然函数转化为累加问题,然后用梯度下降法或者其他算法求解最值问题。拉格朗日乘子法主要是为了解决偏导为0无法求解的问题,通过引入拉格朗日乘子来求解。
EM算法是通过E步和M步求解问题,E步就是为了消除隐变量,即计算出样本属于某种类别的概率。然后使用M步更新迭代参数。循环迭代直到参数波动范围很小,或者达到指定的迭代次数。
EM算法的应用,比如经典的高斯混合模型(GMM)算法上面的应用。隐马尔科夫的应用等等。
下面图片是我手推结果和解释:
该图是EM一个简单示例,了解什么是EM算法。
经典的EM算法的说明解释示例,硬币问题。包括训练和预测代码。
EM算法的推导,以及Jensen不等式的简单介绍说明:
这里上面硬币问题的训练代码和预测代码,核心代码使用C语言实现的,如果要完整代码请联系我。
//阶乘代码
double CFactorial(const int len, const int m){
if (len < m || len < 1 || len < m) { return 0.0; }
double numerator = 1.0, denominator = 1.0;
for (size_t i = 1; i <= m; i++) {
numerator *= 1.0 * (len - i + 1);
denominator *= 1.0 * i;
}
return numerator / denominator;
}
//EM算法求解硬币问题
void EMTrain(const EMDATA *data, const int cls_num, EMRESULT *theta){
if (cls_num < 1 || data->width < 1 || data->height < 1){ TFLYERROR("size error\n"); return; }
const EMDATA *head_data = data;
int *per_group_cls_num = (int *)malloc(head_data->height * sizeof(int)); //硬币问题,只有两种类别。记录正面朝上的硬币个数
memset(per_group_cls_num, 0, head_data->height * sizeof(int));
int *head_sum = per_group_cls_num;
for (size_t i = 0; i < head_data->height; i++) { //统计每组数据中硬币朝上的个数
for (size_t j = 0; j < head_data->width; j++) {
*head_sum += head_data->data[i][j];
}
head_sum++;
}
double pA = 0.0, pB = 0.0;
double thetaA = theta[0], thetaB = theta[1], error = 1e-8; //表示theta的error表示误差
double tmp_thetaA = thetaA, tmp_thetaB = thetaB;
while (1){
//硬币A正面和反面出现的次数 与 硬币B正面和反面出现的次数
double sum_frontA = 0.0, sum_backA = 0.0;
double sum_frontB = 0.0, sum_backB = 0.0;
//E-step 求解
head_sum = per_group_cls_num;
for (size_t i = 0; i < head_data->height; i++) { //正样本
double fact = CFactorial(head_data->width, *head_sum);
pA = fact * pow(thetaA, *head_sum) * pow(1 - thetaA, head_data->width - *head_sum);
pB = fact * pow(thetaB, *head_sum) * pow(1 - thetaB, head_data->width - *head_sum);
pA = pA / (pA + pB); //更新硬币A的概率
pB = 1 - pA; //更新硬币B的概率
sum_frontA += pA * *head_sum; //硬币A正面出现的次数累加和
sum_backA += pA * (head_data->width - *head_sum); //硬币A背面出现的次数累加和
sum_frontB += pB * *head_sum; //硬币B正面出现的次数累加和
sum_backB += pB * (head_data->width - *head_sum); //硬币B背面出现的次数累加和
++head_sum;
}
//M-step 更新参数thetaA和thetaB
thetaA = sum_frontA / (sum_frontA + sum_backA);
thetaB = sum_frontB / (sum_frontB + sum_backB);
printf("thetaA=%lf, thetaB=%lf\n", thetaA, thetaB);
if (fabs(thetaA - tmp_thetaA) < error && fabs(thetaB - tmp_thetaB) < error){ //迭代终止条件,也可以使用
break;
}
tmp_thetaA = thetaA;
tmp_thetaB = thetaB;
}
*theta++ = thetaA; //返回最终的结果
*theta++ = thetaB;
free(per_group_cls_num);
return;
}
//模型预测结果
void EMPredicted(const EMDATA *pred_data, const int cls_num, const EMRESULT *theta, int *pred_cls){
if (cls_num < 1 || pred_data->width < 1 || pred_data->height < 1){ TFLYERROR("size error\n"); return; }
const EMDATA *head_data = pred_data;
int *per_group_cls_num = (int *)malloc(head_data->height * sizeof(int)); //硬币问题,只有两种类别。记录正面朝上的硬币个数
memset(per_group_cls_num, 0, head_data->height * sizeof(int));
int *head_sum = per_group_cls_num;
for (size_t i = 0; i < head_data->height; i++) { //统计每组数据中硬币朝上的个数
for (size_t j = 0; j < head_data->width; j++) {
*head_sum += head_data->data[i][j];
}
head_sum++;
}
double *cls_value = (double *)malloc(cls_num * sizeof(double));//每种类别的类别和
head_sum = per_group_cls_num;
for (int i = 0; i < head_data->height; i++) {
memset(cls_value, 0, cls_num * sizeof(double));
double *head_cls_value = cls_value;
double fact = CFactorial(head_data->width, *head_sum);
const EMRESULT *head_theta = theta;
double sum = 0.0;
//方法1
double max_prob = -1.0; //概率值大于等于0
int real_class = 0;
for (size_t j = 0; j < cls_num; j++) {
*head_cls_value = fact * pow(*head_theta, *head_sum) * pow(1 - *head_theta, head_data->width - *head_sum);
if (max_prob < *head_cls_value) { max_prob = *head_cls_value; real_class = j; }
++head_cls_value; ++head_theta;
}
//方法2
//for (size_t j = 0; j < cls_num; j++) {
// *head_cls_value = fact * pow(*head_theta, *head_sum) * pow(1 - *head_theta, head_data->width - *head_sum);
// sum += *head_cls_value;
// ++head_cls_value; ++head_theta;
//}
//head_cls_value = cls_value;
//double max_prob = -1.0; //概率值大于等于0
//int real_class = 0;
//for (size_t j = 0; j < cls_num; j++) {
// double per_prob = *head_cls_value / sum; //同一组数据,每种类别的概率,取概率最大值作为预测结果
// if (max_prob < per_prob) { max_prob = per_prob; real_class = j; }
// ++head_cls_value;
//}
*pred_cls++ = real_class;
++head_sum;
}
free(cls_value);
free(per_group_cls_num);
}