看了西瓜书中的聚类算法(高斯混合聚类),算法原理很容易明白,接下来就是整理成自己的理解思路,然后一步一步实现算法,那么就来做吧。
下载链接:点击
OK,接下来不废话,直接上代码(以 Matlab publish 形式发布):
% Function: Mixture-of-Gaussians clustering
% data:
% author: try
function Main()
% MAIN  Cluster melon4.0.txt with a Gaussian mixture model fitted by EM.
%   Loads the data, runs at most max_loop EM iterations, plots the final
%   cluster assignment with the component means, then plots the
%   log-likelihood recorded at every iteration.
clc; close all
melon_data = load('melon4.0.txt');
% drop the first column (presumably a sample index -- confirm against the data file)
melon_data(:,1) = [];
% number of mixture components; declared global because helpers may read it
global k;
k = 3;
likelihold = 0;
likeli_value = [];
% step1: initial means, covariances and mixing weights
[mu,sigma,alpha] = init_guassian_para(melon_data,k);
% main loop
max_loop = 100;
for loop = 1:max_loop
    % step2 (E-step): posteriors and component densities under the current parameters
    [post_prob,class_box,guass_value] = byes_posterior(melon_data,mu,sigma,alpha);
    % step4: log-likelihood of the parameters used in the E-step. It is evaluated
    % before the M-step so guass_value and alpha come from the same parameter
    % set; pairing old densities with the updated alpha is not a likelihood.
    cur_likelihold = calcu_likelihold(melon_data,guass_value,alpha);
    % step3 (M-step): re-estimate the parameters from the posteriors
    [mu,sigma,alpha] = update_guassian_para(melon_data,post_prob);
    [is_stop,likelihold] = is_stop_iteration(cur_likelihold,likelihold,loop);
    likeli_value = [likeli_value;likelihold]; %#ok<AGROW>
    % Stop on convergence or on the last allowed iteration. The loop variable
    % never exceeds max_loop, so "loop > max_loop" would never fire.
    if is_stop || loop == max_loop
        show(mu,class_box);
        fprintf('迭代次数为: %d\n',loop);
        break;
    end
end
figure; plot(likeli_value,'-r');
end
%% subfunction
%step1
function [mu,sigma,alpha]= init_guassian_para(melon_data,k)
% INIT_GUASSIAN_PARA  Initial parameters of a k-component Gaussian mixture.
%   melon_data : m x n samples, one per row
%   k          : number of mixture components
%   mu         : k x n initial means, copied from selected sample rows
%   sigma      : (k*n) x n, k stacked n x n covariances, each 0.1*eye(n)
%   alpha      : 1 x k uniform mixing weights (1/k each)
[m,n] = size(melon_data);
alpha = ones(1,k)*(1/k);
if k == 3 && m >= 27
    % keep the hand-picked samples 6, 22 and 27 so k == 3 results are unchanged
    init_idx = [6,22,27];
else
    % otherwise spread the initial means evenly over the sample order
    init_idx = round(linspace(1,m,k));
end
mu = melon_data(init_idx,:);
% same isotropic starting covariance for every component, sized to the data dimension
sigma = repmat(0.1*eye(n),k,1);
end
%step2
function [post_prob,class_box,guass_value]= byes_posterior(melon_data,mu,sigma,alpha)
% BYES_POSTERIOR  E-step: posterior probability of each component for every sample.
%   melon_data  : m x n samples, one per row
%   mu          : k x n component means
%   sigma       : (k*n) x n, covariance of component j in rows (j-1)*n+1 : j*n
%   alpha       : 1 x k mixing weights
%   post_prob   : m x k, post_prob(i,j) = alpha(j)*p(x_i|j) / sum_l alpha(l)*p(x_i|l)
%   class_box   : m x (n*k); columns (j-1)*n+1 : j*n list, top to bottom, the
%                 samples whose most probable component is j; unused rows stay zero
%   guass_value : m x k component densities p(x_i | mu_j, sigma_j)
[m,n] = size(melon_data);
% component count comes from mu, so no global state is needed
k = size(mu,1);
post_prob = zeros(m,k);
guass_value = zeros(m,k);
class_box = zeros(m,n*k);
% next free row per cluster; searching for zero rows would overwrite
% samples that legitimately contain a 0 coordinate
fill_count = zeros(1,k);
for i = 1:m
    for j = 1:k
        guass_value(i,j) = guassian_prob_func(melon_data(i,:),mu(j,:),sigma((j-1)*n+1:j*n,:));
    end
    weighted = alpha(:)'.*guass_value(i,:);
    post_prob(i,:) = weighted/sum(weighted);
    % hard assignment used only for plotting
    [~,max_ind] = max(post_prob(i,:));
    fill_count(max_ind) = fill_count(max_ind) + 1;
    class_box(fill_count(max_ind),(max_ind-1)*n+1:max_ind*n) = melon_data(i,:);
end
end
function guassian_val = guassian_prob_func(x,mu,sigma)
% GUASSIAN_PROB_FUNC  Multivariate normal density N(x | mu, sigma) for one sample.
%   x     : 1 x n (or n x 1) sample
%   mu    : 1 x n (or n x 1) mean
%   sigma : n x n covariance
%   Evaluates formula 9-28 one sample at a time; the dimension n is taken
%   from x instead of being fixed at 2.
x = x(:);
mu = mu(:);
n = numel(x);
d = x - mu;
coeff = 1/((2*pi)^(n/2)*sqrt(det(sigma)));
% solve sigma \ d instead of forming inv(sigma) explicitly
guassian_val = coeff*exp(-0.5*(d'*(sigma\d)));
end
function [mu,sigma,alpha] = update_guassian_para(melon_data,post_prob)
% UPDATE_GUASSIAN_PARA  M-step: re-estimate mixture parameters from posteriors.
%   melon_data : m x n samples, one per row
%   post_prob  : m x k posterior probabilities from the E-step
%   mu         : k x n responsibility-weighted means
%   sigma      : (k*n) x n, weighted covariance of component i in rows (i-1)*n+1 : i*n
%   alpha      : 1 x k mixing weights (effective sample count / m)
%   A component with zero total responsibility yields NaN parameters.
[m,n] = size(melon_data);
k = size(post_prob,2);
alpha = zeros(1,k);
mu = zeros(k,n);
sigma = zeros(k*n,n);
for i = 1:k
    w = post_prob(:,i);     % responsibilities of component i, m x 1
    Nk = sum(w);            % effective number of samples in component i
    W = repmat(w,1,n);
    mu(i,:) = sum(melon_data.*W,1)/Nk;
    D = melon_data - repmat(mu(i,:),m,1);
    sigma((i-1)*n+1:i*n,:) = (D.*W)'*D/Nk;
    alpha(i) = Nk/m;
end
end
function likelihold = calcu_likelihold(melon_data,guass_value,alpha)
% CALCU_LIKELIHOLD  Log-likelihood of the samples under a Gaussian mixture.
%   melon_data  : m x n samples (only the row count m is used)
%   guass_value : m x k component densities of each sample
%   alpha       : mixing weights, 1 x k (a k x 1 column is also accepted)
%   likelihold  : sum over samples j of log( sum_i alpha(i)*guass_value(j,i) );
%                 0 when there are no samples
m = size(melon_data,1);
likelihold = sum(log(guass_value(1:m,:)*alpha(:)));
end
function [is_stop,LLD] = is_stop_iteration(cur_likelihold,likelihold,loop,tol)
% IS_STOP_ITERATION  Convergence test on successive log-likelihood values.
%   cur_likelihold : log-likelihood of the current iteration
%   likelihold     : log-likelihood of the previous iteration
%   loop           : current iteration number (the first iteration never stops)
%   tol            : optional absolute-change threshold, default 0.001
%   is_stop        : 1 when loop > 1 and |cur - previous| < tol, else 0
%   LLD            : cur_likelihold, to be fed back as the next "previous" value
if nargin < 4
    tol = 0.001;
end
% keep the numeric 0/1 flag of the original interface
is_stop = double(loop > 1 && abs(cur_likelihold - likelihold) < tol);
LLD = cur_likelihold;
end
function show(mu,class_box)
% SHOW  Plot each cluster's samples together with its component mean.
%   mu        : k x 2 component means (drawn in red)
%   class_box : m x 2k; columns 2j-1:2j hold the samples assigned to
%               component j, padded with all-zero rows
%   Padding rows are skipped, so they no longer appear as points at the
%   origin (a genuine sample at exactly (0,0) would be skipped as well).
markers = {'^','s','o'};
colors = {'y','g','b','c','m','k'};
for j = 1:size(mu,1)
    mk = markers{mod(j-1,numel(markers))+1};
    cl = colors{mod(j-1,numel(colors))+1};
    pts = class_box(:,2*j-1:2*j);
    pts = pts(any(pts ~= 0,2),:);   % drop the zero padding rows
    plot(mu(j,1),mu(j,2),['r' mk]);
    hold on;
    plot(pts(:,1),pts(:,2),[cl mk]);
end
xlabel('density');ylabel('sugar rate');
end
迭代次数为: 48