EM算法的具体推导过程请参见前一篇博客,本文只给出高斯混合模型(GMM)下EM算法的两种代码实现。
一、matlab实现
1.matlab代码
% EM for a 2-D Gaussian mixture model on synthetic data, followed by a
% fuzzy c-means (FCM) clustering of the same data for visual comparison.
%
% Figures: 1 = raw samples, 2 = log-likelihood per EM iteration,
%          3 = true vs. estimated means, 4 = EM hard assignment,
%          5 = FCM hard assignment.
% Requires mvnrnd (Statistics toolbox) and fcm (Fuzzy Logic toolbox).
close all;
clear;
clc;
%% Synthetic data
M=3;            % number of Gaussian components
N=600;          % requested total number of samples
th=0.000001;    % convergence threshold on the largest relative parameter change
K=2;            % dimension of every sample (used in the Gaussian normaliser)
max_iter=1000;  % safety cap so a non-converging run cannot loop forever
% Ground-truth parameters used to generate the data
a_real =[2/3;1/6;1/6];
mu_real=[3 4 6;
         5 3 7];
cov_real(:,:,1)=[5 0;
                 0 0.2];
cov_real(:,:,2)=[0.1 0;
                 0 0.1];
cov_real(:,:,3)=[0.1 0;
                 0 0.1];
% Draw each component's share of the samples; the mean is passed as a row
% vector (1-by-d), and every draw is transposed so that x is K-by-N with
% one sample per column.
x=[ mvnrnd( mu_real(:,1)' , cov_real(:,:,1) , round(N*a_real(1)) )' ,...
    mvnrnd( mu_real(:,2)' , cov_real(:,:,2) , round(N*a_real(2)) )' ,...
    mvnrnd( mu_real(:,3)' , cov_real(:,:,3) , round(N*a_real(3)) )' ];
N=size(x,2);    % use the number of samples actually generated (rounding-safe)
figure(1),plot(x(1,:),x(2,:),'.')
%% EM
% Initial guesses
a=[1/3,1/3,1/3];      % mixing weights
mu=[1 2 3;            % component means, one column per component
    2 1 4];
% NOTE: the variable name "cov" shadows MATLAB's built-in cov function for
% the rest of the script; it is kept so the workspace variable names are
% unchanged.
cov(:,:,1)=[1 0;      % component covariances
            0 1];
cov(:,:,2)=[1 0;
            0 1];
cov(:,:,3)=[1 0;
            0 1];
t=inf;                % largest relative parameter change of the last iteration
count=0;              % number of completed iterations
px=zeros(1,N);        % mixture density of every sample under the current parameters
figure(2),hold on
while t>=th && count<max_iter
    a_old  = a;
    mu_old = mu;
    cov_old= cov;
    % Component densities: rznk_p(cm,cn) = N(x_cn | mu_cm, cov_cm)
    rznk_p=zeros(M,N);
    for cm=1:M
        mu_cm=mu(:,cm);
        cov_cm=cov(:,:,cm);
        for cn=1:N
            d=x(:,cn)-mu_cm;
            rznk_p(cm,cn)=exp(-0.5*d'/cov_cm*d);
        end
        rznk_p(cm,:)=rznk_p(cm,:)/sqrt(det(cov_cm));
    end
    rznk_p=rznk_p*(2*pi)^(-K/2);
    % E step: responsibilities rznk(cm,cn) = P(component cm | x_cn)
    rznk=zeros(M,N);
    pikn=zeros(1,M);
    for cn=1:N
        for cm=1:M
            pikn(1,cm)=a(cm)*rznk_p(cm,cn);   % weighted component density
        end
        px(cn)=sum(pikn);                     % mixture density of sample cn
        for cm=1:M
            rznk(cm,cn)=pikn(1,cm)/px(cn);
        end
    end
    % M step
    % Effective number of samples assigned to each component
    nk=zeros(1,M);
    for cm=1:M
        for cn=1:N
            nk(1,cm)=nk(1,cm)+rznk(cm,cn);
        end
    end
    a=nk/N;
    % Means: responsibility-weighted average of the samples
    for cm=1:M
        rznk_sum_mu=zeros(K,1);
        for cn=1:N
            rznk_sum_mu=rznk_sum_mu+rznk(cm,cn)*x(:,cn);
        end
        mu(:,cm)=rznk_sum_mu/nk(cm);
    end
    % Covariances: responsibility-weighted scatter around the new means
    for cm=1:M
        rznk_sum_cov=zeros(K,K);
        for cn=1:N
            d=x(:,cn)-mu(:,cm);
            rznk_sum_cov=rznk_sum_cov+rznk(cm,cn)*(d*d');
        end
        cov(:,:,cm)=rznk_sum_cov/nk(cm);
    end
    % Largest relative change among weights, means and covariances
    t=max([norm(a_old(:)-a(:))/norm(a_old(:));norm(mu_old(:)-mu(:))/norm(mu_old(:));norm(cov_old(:)-cov(:))/norm(cov_old(:))]);
    % Log-likelihood of the whole data set under the parameters used in this
    % E step (sum over ALL samples, not only the last one)
    temp_f=sum(log(px));
    plot(count,temp_f,'r+')
    count=count+1;
end %while
hold off
f=sum(log(px));   % log-likelihood evaluated in the final E step
% Print the estimated parameters
a
mu
cov
figure(3),
hold on
plot(x(1,:),x(2,:),'k.');
plot(mu_real(1,:),mu_real(2,:),'*c');   % true means
plot(mu(1,:),mu(2,:),'+r');             % estimated means
hold off
% Hard assignment from EM: each sample gets the colour of the component with
% the largest responsibility (component 1 black, 2 blue, 3 red)
styles={'k.','b.','r.'};
figure(4),
hold on
[~,em_label]=max(rznk,[],1);
for c=1:M
    plot(x(1,em_label==c),x(2,em_label==c),styles{c});
end
hold off
% Fuzzy c-means clustering of the same data for comparison
[center, U, OBJ_FCN]=fcm(x',3);
figure(5),
hold on
[~,fcm_label]=max(U,[],1);
for c=1:3
    plot(x(1,fcm_label==c),x(2,fcm_label==c),styles{c});
end
plot(center(:,1),center(:,2),'c*')
hold off
2.结果显示
二、Java实现
开发环境(IDE)为:Eclipse
1.Java代码
package cn.sxt.oo2;
/**
 * EM algorithm for a one-dimensional Gaussian mixture.
 * 1. E-step: compute the posterior probability of every cluster for every point.
 * 2. M-step: re-estimate means, variances and mixing weights from those posteriors.
 * 3. Repeat both steps until the parameters stop changing.
 */
public class MyEM {
    /** Observed one-dimensional samples. */
    private static final double[] points={1.0,1.3,2.2,2.6,2.8,5.0,7.3,7.4,7.5,7.7,7.9};
    /** Posterior weights: w[i][j] = P(cluster i | point j). */
    private static double[][] w;
    /** Current cluster means (initial guess). */
    private static double[] means = {7.7,2.3};
    /** Current cluster variances (initial guess). */
    private static double[] variances= {1,1};
    /** Current mixing weights, one per cluster; the number of clusters is fixed at 2 here. */
    private static double[] probs = {0.5,0.5};

    /** Iteration stops once both the means and the variances change by less than this. */
    private static final double TOLERANCE = 0.01;
    /** Upper bound on the number of EM iterations, so the loop always terminates. */
    private static final int MAX_ITERATIONS = 1000;
    /** Lower bound for a variance, so a collapsed cluster cannot cause a division by zero. */
    private static final double MIN_VARIANCE = 1e-9;

    /**
     * Gaussian density, i.e. the likelihood p(x|c) of a point under one cluster.
     *
     * @param point    the sample x
     * @param mean     mean of the cluster
     * @param variance variance of the cluster
     * @return value of the normal density at {@code point}
     */
    private static double gaussianPro(double point,double mean,double variance){
        double deviation = point-mean;
        return (1/(Math.sqrt(2*Math.PI)*Math.sqrt(variance)))*Math.exp(-deviation*deviation/(2*variance));
    }

    /**
     * Core of the E-step: posterior probability of every cluster for every point.
     *
     * @return matrix indexed as [cluster][point]
     */
    private static double[][] countPostprob(double[] means,double[] variances,double[] points,double[] probs){
        int clusterNum = means.length;
        int pointNum = points.length;
        double[][] postProbs = new double[clusterNum][pointNum];
        for(int j = 0;j<pointNum;j++){
            // Weighted likelihood of each cluster and their sum (the mixture density).
            double denominator = 0.0;
            for(int i = 0;i<clusterNum;i++){
                postProbs[i][j] = gaussianPro(points[j], means[i], variances[i])*probs[i];
                denominator += postProbs[i][j];
            }
            for(int i = 0;i<clusterNum;i++){
                postProbs[i][j] /= denominator;
            }
        }
        return postProbs;
    }

    /** E-step: refresh the posterior weights from the current parameters. */
    private static void eStep(){
        w = countPostprob(means, variances, points, probs);
    }

    /** Sums each row of the weight matrix: the effective number of points per cluster. */
    private static double[] sumWeights(double[][] w){
        double[] totals = new double[w.length];
        for (int i = 0; i < w.length; i++) {
            for (int j = 0; j < w[i].length; j++) {
                totals[i] += w[i][j];
            }
        }
        return totals;
    }

    /**
     * M-step, part 1: weighted mean of the points for every cluster.
     */
    private static double[] guessMean(double[][] w,double[] points){
        double[] wi = sumWeights(w);
        double[] means = new double[w.length];
        for(int i = 0;i<w.length;i++){
            for(int j = 0;j<points.length;j++){
                means[i]+=(w[i][j]*points[j]);
            }
            means[i] /= wi[i];
        }
        return means;
    }

    /**
     * M-step, part 2: weighted variance of the points around the re-estimated means.
     */
    private static double[] guessVariance(double[][] w,double[] points){
        double[] wi = sumWeights(w);
        double[] means = guessMean(w, points);
        double[] variances = new double[w.length];
        for(int i = 0;i<w.length;i++){
            for(int j = 0;j<points.length;j++){
                double deviation = points[j]-means[i];
                variances[i] +=(w[i][j]*deviation*deviation);
            }
            variances[i] /= wi[i];
            // Keep the variance strictly positive so the density stays finite.
            variances[i] = Math.max(variances[i], MIN_VARIANCE);
        }
        return variances;
    }

    /**
     * M-step, part 3: mixing weight of every cluster (average posterior over all points).
     */
    private static double[] guessProb(double[][] w){
        double[] probs = sumWeights(w);
        for(int i = 0;i<probs.length;i++){
            probs[i] /=w[i].length;
        }
        return probs;
    }

    /** M-step: replace all parameters with their new estimates. */
    private static void mStep(){
        means = guessMean(w, points);
        variances = guessVariance(w, points);
        probs = guessProb(w);
    }

    /**
     * Total absolute change between two parameter vectors of successive iterations.
     * Absolute values are taken per element so that changes of opposite sign
     * cannot cancel each other out.
     */
    private static double threshold(double[] bef_values,double[] values){
        double diff = 0.0;
        for(int i = 0 ; i < values.length;i++){
            diff += Math.abs(values[i]-bef_values[i]);
        }
        return diff;
    }

    /**
     * Runs EM until the means and variances have converged (or the iteration
     * cap is reached) and prints the mixing weight of every cluster, one per line.
     */
    public static void main(String[] args)throws Exception{
        w = new double[means.length][points.length];
        double[] bef_means;
        double[] bef_var;
        int iterations = 0;
        do{
            // mStep() assigns freshly allocated arrays, so these keep the old values.
            bef_means = means;
            bef_var = variances;
            eStep();
            mStep();
            iterations++;
            // Keep iterating while either parameter set is still moving.
        }while(iterations<MAX_ITERATIONS
                &&(threshold(bef_means, means)>=TOLERANCE||threshold(bef_var, variances)>=TOLERANCE));
        for(double prob:probs)
            System.out.println(prob);
    }
}
2.结果显示