# 基于K均值(KMeans)与EM法估计高斯混合模型(GMM)参数的图像分割matlab程序

function [sg,normMuDiff]=KMeansEMSeg(img,k,Mu,eps,maxIter)
%KMEANSEMSEG Segment an image by fitting a k-component Gaussian mixture with EM.
%
%   [sg,normMuDiff] = KMeansEMSeg(img,k,Mu,eps)
%   [sg,normMuDiff] = KMeansEMSeg(img,k,Mu,eps,maxIter)
%
%   Inputs
%     img     - m-by-n-by-c image; converted to double internally.
%     k       - number of mixture components (clusters).
%     Mu      - k-by-c matrix of initial cluster means, one row per cluster.
%     eps     - convergence limit: iteration stops once norm(old_Mu-Mu) < eps.
%               (Inside this function the name hides MATLAB's built-in eps.)
%     maxIter - optional upper bound on the number of EM iterations
%               (default 1000). A warning is issued when it is reached.
%
%   Outputs
%     sg         - m-by-n label map; each pixel holds the index (1..k) of the
%                  component with the largest responsibility.
%     normMuDiff - column vector with norm(old_Mu-Mu) for every iteration.
%
%   Side effects
%     Opens a new figure, displays the label map after every iteration and
%     records each displayed frame into 'test_KMeansEM.gif' in the current
%     folder.

if nargin<5 || isempty(maxIter)
    maxIter=1000;
end

m=size(img,1);
n=size(img,2);
c=size(img,3);
N=m*n;
M=reshape(double(img),N,c);   % one row per pixel, one column per channel

% Validate the initial means up front: on releases with implicit expansion a
% wrong-sized Mu would otherwise broadcast silently and give garbage.
if ~isequal(size(Mu),[k c])
    error('KMeansEMSeg:badMu', ...
        'Mu must be a %d-by-%d matrix (k rows, one column per image channel).',k,c);
end
Mu=double(Mu);                % integer-class means would break mixed arithmetic

normMuDiff=[];
pc=rand(1,k); pc=pc/sum(pc);          % random mixing proportions, normalised
Cov=repmat(eye(c),[1 1 k])*30^2;      % initial covariance of every component
w=zeros(N,k);                         % responsibilities: pixel i, component j

iter=1;
flag=1;
figure;
while flag
    old_Mu=Mu;

    %-------------------- E step --------------------%
    % Responsibility of every component for every pixel, using the
    % Mahalanobis distance to the component mean.
    for j=1:k
        % NOTE(review): the density is scaled by sqrt(norm(Cov)) as in the
        % original code; the textbook Gaussian uses sqrt(det(Cov)). Kept
        % unchanged so existing segmentation results are not altered.
        scale=sqrt(norm(Cov(:,:,j)));
        D=bsxfun(@minus,M,Mu(j,:));
        % D/Cov solves the linear system instead of forming inv(Cov).
        dis=sum((D/Cov(:,:,j)).*D,2);
        % Floor the weights at 1e-5 (suggested range 1e-5..1e-4): first
        % safeguard against near-singular covariances blowing up the error.
        w(:,j)=max(pc(j)*exp(-dis/2)/scale,10^-5);
    end
    w=bsxfun(@rdivide,w,sum(w,2));    % each pixel's responsibilities sum to 1

    %-------------------- M step --------------------%
    % Re-estimate mean, covariance and total weight of every component.
    for j=1:k
        pc(j)=sum(w(:,j));
        Mu(j,:)=w(:,j)'*M/pc(j);
        D=bsxfun(@minus,M,Mu(j,:));
        % Weighted scatter matrix; row-scaling D avoids building an N-by-N
        % sparse diagonal weight matrix.
        covJ=(D'*bsxfun(@times,D,w(:,j)))/pc(j);

        % Second safeguard against singular covariances: push every entry
        % whose magnitude is below 10 out to +/-10 (suggested range 1..10).
        % Entries that are exactly zero stay zero because sign(0) is 0.
        tooSmall=abs(covJ)<10;
        covJ(tooSmall)=sign(covJ(tooSmall)).*10;
        Cov(:,:,j)=covJ;
    end
    pc=pc/sum(pc);                    % mixing proportions sum to 1

    % Record the change of the means and decide whether to stop.
    normDiff=norm(old_Mu-Mu);
    normMuDiff=[normMuDiff;normDiff]; %#ok<AGROW> iteration count unknown
    if normDiff<eps
        flag=0;
    elseif iter>=maxIter
        warning('KMeansEMSeg:maxIterReached', ...
            'Stopped after %d iterations without reaching the limit %g.',maxIter,eps);
        flag=0;
    end

    % Hard assignment: most responsible component per pixel. Taking the max
    % along dimension 2 also works when k == 1.
    [~,Label]=max(w,[],2);
    sg=reshape(Label,m,n);

    % Record the current label map as one GIF frame.
    imshow(mat2gray(sg));
    F=getframe(gcf);
    I=frame2im(F);
    [I,map]=rgb2ind(I,256);
    if iter == 1
        imwrite(I,map,'test_KMeansEM.gif','gif','Loopcount',inf,'DelayTime',0.2);
    else
        imwrite(I,map,'test_KMeansEM.gif','gif','WriteMode','append','DelayTime',0.2);
    end
    iter=iter+1;
end


% Demo driver: segment an RGB image into five colour classes with
% KMeansEMSeg and plot how the change of the cluster means evolves.
%
% The original script started with `clear all` and then used `img` without
% ever loading it, so the call below always failed. The workspace is no
% longer cleared, and an image is requested when `img` is not already defined.
close all;clc;
if ~exist('img','var')
    [imgName,imgDir]=uigetfile( ...
        {'*.jpg;*.jpeg;*.png;*.bmp;*.tif;*.tiff','Image files'}, ...
        'Select the image to segment');
    if isequal(imgName,0)
        error('KMeansEMDemo:noImage','No image was selected.');
    end
    img=imread(fullfile(imgDir,imgName));
end

% Initial means: RGB values sampled at one point in each of the yellow,
% green, red, white and purple regions of the picture.
Mu=[255 220 50;...
    115 124 31;...
    216 50 39;...
    247 215 185;...
    75 45 75];
[sg,normMuDiff]=KMeansEMSeg(img,5,Mu,1);
figure,imshow(mat2gray(sg));
figure,plot(1:length(normMuDiff),normMuDiff,'g-');
title('误差变化曲线');

10-22
09-01 1万+

11-07 3090
02-01 7193
07-25 1269
03-14 1万+
08-19 1万+
04-28 6344