EM算法matlab和Java实现

EM算法具体过程看前一篇博客

一、matlab实现

1.matlab代码

close all;
clear;
clc;

%% 
M=3;          % 高斯数量
N=600;        % 数据样本总数
th=0.000001;  % 聚合阀值
K=2;          % 输出信号保留

% 待生成数据的参数
a_real =[2/3;1/6;1/6];
mu_real=[3 4 6;
         5 3 7];
cov_real(:,:,1)=[5 0;
                 0 0.2];
cov_real(:,:,2)=[0.1 0;
                 0 0.1];
cov_real(:,:,3)=[0.1 0;
                 0 0.1];                     
% 这里生成的数据全部符合标准
x=[ mvnrnd( mu_real(:,1) , cov_real(:,:,1) , round(N*a_real(1)) )' ,...
    mvnrnd( mu_real(:,2) , cov_real(:,:,2) , round(N*a_real(2)) )' ,...
    mvnrnd( mu_real(:,3) , cov_real(:,:,3) , round(N*a_real(3)) )' ];

figure(1),plot(x(1,:),x(2,:),'.')

%% EM 
% 参数初始化
a=[1/3,1/3,1/3]; %各类的比例(权重)
mu=[1 2 3;       %均值初始化
    2 1 4];
cov(:,:,1)=[1 0; %协方差初始化
            0 1];
cov(:,:,2)=[1 0;
            0 1];
cov(:,:,3)=[1 0;
            0 1];

t=inf;
count=0;
figure(2),hold on
while t>=th
    a_old  = a;
    mu_old = mu;
    cov_old= cov;      
    rznk_p=zeros(M,N);%生成M行N列零矩阵
    for cm=1:M
        mu_cm=mu(:,cm);
        cov_cm=cov(:,:,cm);
        for cn=1:N  %计算Pi(x)
            p_cm=exp(-0.5*(x(:,cn)-mu_cm)'/cov_cm*(x(:,cn)-mu_cm));
            rznk_p(cm,cn)=p_cm;
        end
        rznk_p(cm,:)=rznk_p(cm,:)/sqrt(det(cov_cm));
    end
    rznk_p=rznk_p*(2*pi)^(-K/2);
%E step
    %开始求rznk 相当于Pr(i|Xt)
    rznk=zeros(M,N);%r(Z
    pikn=zeros(1,M);%r(Z
    pikn_sum=0;
    for cn=1:N
        for cm=1:M%计算p(x|*)概率分布
            pikn(1,cm)=a(cm)*rznk_p(cm,cn);
%           pikn_sum=pikn_sum+pikn(1,cm);
        end
        for cm=1:M
            rznk(cm,cn)=pikn(1,cm)/sum(pikn);
        end
    end
        %求rank结束
% M step
    nk=zeros(1,M);
    for cm=1:M
        for cn=1:N
            nk(1,cm)=nk(1,cm)+rznk(cm,cn);
        end
    end
    a=nk/N;
    rznk_sum_mu=zeros(M,1);
        
    % 求均值MU
    %nk(cm)  相当于ni
    %rznk_sum_mu  ni*Xt
    for cm=1:M
        rznk_sum_mu=0;
        for cn=1:N
            rznk_sum_mu=rznk_sum_mu+rznk(cm,cn)*x(:,cn);
        end
        mu(:,cm)=rznk_sum_mu/nk(cm);
    end
    
    % 求协方差COV   
    for cm=1:M
        rznk_sum_cov=zeros(K,K);
        for cn=1:N%求协方差(Hi-u)^2
            rznk_sum_cov=rznk_sum_cov+rznk(cm,cn)*(x(:,cn)-mu(:,cm))*(x(:,cn)-mu(:,cm))';
        end
        cov(:,:,cm)=rznk_sum_cov/nk(cm);
    end
         
    t=max([norm(a_old(:)-a(:))/norm(a_old(:));norm(mu_old(:)-mu(:))/norm(mu_old(:));norm(cov_old(:)-cov(:))/norm(cov_old(:))]);
 
    temp_f=sum(log(sum(pikn)));
    plot(count,temp_f,'r+')
    count=count+1;        
end  %while 
   
hold off
f=sum(log(sum(pikn)));
  
% 输出结果
a
mu
cov

figure(3),
hold on
plot(x(1,:),x(2,:),'k.');
plot(mu_real(1,:),mu_real(2,:),'*c');
plot(mu(1,:),mu(2,:),'+r');
hold off

figure(4),
hold on
for i=1:N
    [max_temp,ind_temp]=max(rznk(:,i));
    if ind_temp==1
        plot(x(1,i),x(2,i),'k.');
    end
    if ind_temp==2
        plot(x(1,i),x(2,i),'b.');
    end
    if ind_temp==3
        plot(x(1,i),x(2,i),'r.');
    end    
end
        
        
%fcm聚类
[center, U, OBJ_FCN]=fcm(x',3);
figure(5),
hold on
for i=1:N
    [max_temp,ind_temp]=max(U(:,i));
    if ind_temp==1
        plot(x(1,i),x(2,i),'k.');
    end
    if ind_temp==2
        plot(x(1,i),x(2,i),'b.');
    end
    if ind_temp==3
        plot(x(1,i),x(2,i),'r.');
    end    
end
        
plot(center(:,1),center(:,2),'c*')

hold off

2.结果显示
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

二、Java实现

编译工具为:eclipse
1.Java代码

package cn.sxt.oo2;



/**
 * 一維情況下的EM算法實現
 *1、求期望(e-step)
 *2、期望最大化(估值)(M-step)
 *3、循環以上兩部直到收斂
 */
public class MyEM {
   private static final double[] points={1.0,1.3,2.2,2.6,2.8,5.0,7.3,7.4,7.5,7.7,7.9};
   private static double[][] w;//权值
   private static double[] means = {7.7,2.3};//均值
   private static double[] variances= {1,1};//方差
   private static double[] probs = {0.5,0.5};//每个类的概率;这里默认选择k=2了;
   
   /**
    * 高斯分布计算公式,也就是先验概率
   
    */
   //p(x|c)的计算
   private static double gaussianPro(double point,double mean,double variance){
	   double prob = 0.0;
	   prob = (1/(Math.sqrt(2*Math.PI)*Math.sqrt(variance)))*Math.exp(-(point-mean)*(point-mean)/(2*variance));
	   return prob;
   }
   /**
    * E-step的主要逻辑
    
    */
   private static double[][] countPostprob(double[] means,double[] variances,double[] points,double[] probs){
       int clusterNum = means.length;
       int pointNum = points.length;
	   double[][] postProbs = new double[clusterNum][pointNum];	
	   double[] denominator = new double[pointNum];
	   for(int m = 0;m <pointNum;m++){
		   denominator[m] = 0.0;
		   for(int n = 0;n<clusterNum;n++){
			   denominator[m]+=(gaussianPro(points[m], means[n], variances[n])*probs[n]);
		   }
	   }
	   for(int i = 0;i<clusterNum;i++){
		   for(int j = 0;j<pointNum;j++){
			   postProbs[i][j]=(gaussianPro(points[j], means[i], variances[i])*probs[i])/(denominator[j]);
		   }
	   }
        return postProbs;
   }
   private static void  eStep(){
	   w = countPostprob(means, variances, points, probs);
   }
   /**
    * M-step的主要逻辑之一:由E-step得到的期望,重新计算均值
    */
   private static double[] guessMean(double[][] w,double[] points){
	  
	   int wLength = w.length;
	   double[] means = new double[w.length];
	   double[] wi = new double[wLength];
	   for (int m = 0; m < wLength; m++) {
		   wi[m] = 0.0;
		for(int n = 0; n<points.length;n++){
			wi[m] += w[m][n];
		}
	  }
	   for(int i = 0;i<w.length;i++){
		   means[i] = 0.0;
		   for(int j = 0;j<points.length;j++){
			   means[i]+=(w[i][j]*points[j]);
		   }
		   means[i] /= wi[i];
	   }
	   return means;
   }
   /**
    * M-step的主要逻辑之一:由E-step得到的期望,重新计算方差
    
    */
   private static double[] guessVariance(double[][] w,double[] points){
	   int wLength = w.length;
	   double[] means = new double[w.length];
	   double[] variances = new double[wLength];
	   double[] wi = new double[wLength];
	   for (int m = 0; m < wLength; m++) {
		   wi[m] = 0.0;
		for(int n = 0; n<points.length;n++){
			wi[m] += w[m][n];
		}
	  }
	   means = guessMean(w, points);
	   for(int i = 0;i<wLength;i++){
		   variances[i] = 0.0;
		   for(int j = 0;j<points.length;j++){
			   variances[i] +=(w[i][j]*(points[j]-means[i])*(points[j]-means[i])); 
		   }
		   variances[i] /= wi[i];
	   }
	   
	   return variances;
   }
   /**
    * M-step的主要逻辑之一:由E-step得到的期望,重新计算概率
    *
    */
   private static double[] guessProb(double[][] w){
	   int wLength = w.length;
	   double[] probs = new double[wLength];
	   for(int i = 0;i<wLength;i++){
		   probs[i] = 0.0;
		   for(int j = 0;j<w[i].length;j++){
			   probs[i]+=w[i][j];
		   }
		   probs[i] /=w[i].length;
	   }
	   return probs;
   }
   private static void mStep(){
	   means = guessMean(w, points);
	   variances = guessVariance(w, points);
	   probs = guessProb(w);
   }
   /**
    * 计算前后两次迭代的参数的差值
    * 
    */
   private static double threshold(double[] bef_values,double[] values){
	   double diff = 0.0;
	   for(int i = 0 ; i < values.length;i++){
		   diff += (values[i]-bef_values[i]);
	   }
	   return Math.abs(diff);
   }
   public static void main(String[] args)throws Exception{
	   
	   int k = 2;
	   w = new double[k][points.length];
	   double[] bef_means;
	   double[] bef_var;
	   do{
		  bef_means = means;
		  bef_var = variances;
	      eStep();
	      mStep();
	   }while(threshold(bef_means, means)<0.01&&threshold(bef_var, variances)<0.01);
	   for(double prob:probs)
	       System.out.println(prob);   
   }
}

2.结果显示
在这里插入图片描述

评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

唐维康

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值