看了西瓜书的聚类算法,算法原理很容易明白,接下来就是整理成自己的理解思路,然后一步一步来实现算法,那么就来做吧。
k-means cluster 算法步骤
输入:txt样本数据;输出:样本的归属类
step1:初始化工作
输入:样本,k;输出:初始均值向量
step1.1:载入数据,选择k值,初始化存放每个类的容器
repeat
step2:根据距离相似度划分所有样本对应的类 (for 样本数量)
step2.1:更新样本所属类:计算样本xi与各均值向量的距离,根据距离最近原则将样本xi放入对应类容器中
step2.2:计算更新后的均值向量:求均值即可
step3:判断是否应该停止更新(for k的数量)
step3.1:若ui’!=ui,将ui更新为ui’,否则保持当前均值向量不变
step3.2:若所有的均值向量均未更新,停止程序
end repeat
ok啦,接下来不废话上代码(Matlab发布形式)
K-Means代码及样本下载链接
function Main()
%MAIN Run k-means clustering on the samples stored in melon4.0.txt.
%   Reference: "Machine Learning" by Zhihua Zhou.
%
%   Loads the data file, drops its first column, builds k = 3 initial mean
%   vectors and then alternates assignment / mean-update steps until the
%   mean vectors stop changing or max_loop iterations have been executed.
%   On convergence the number of iterations is printed.
clc
close all

% 1.0 initialisation
melon_data = load('melon4.0.txt');
% Drop the first column BEFORE plotting, so that columns 1 and 2 really are
% the ones labelled density / sugar rate below (presumably the first column
% of the file is a sample index -- confirm against the data file).
melon_data(:,1) = [];
k = 3;
cluster_vector = init_cluster_vector(k,melon_data);

% show the original samples together with the initial mean vectors
plot(melon_data(:,1),melon_data(:,2),'+');hold on
plot(cluster_vector(:,1),cluster_vector(:,2),'pg');
xlabel('density');ylabel('sugar rate');

% 2.0 main loop
max_loop = 10;
converged = false;
for loop = 1:max_loop
    sample_vector = cluster(melon_data,cluster_vector);        % 2.1 assign samples
    cur_cluster_vector = update_cluter_vector(sample_vector);  % 2.2 recompute means
    show(sample_vector,cur_cluster_vector);
    if is_stop(cur_cluster_vector,cluster_vector)
        fprintf('循环次数为:%d\n',loop);
        converged = true;
        break;
    end
    % means moved: keep iterating from the new positions
    cluster_vector = cur_cluster_vector;
end

% make a silent early exit visible instead of pretending it converged
if ~converged
    warning('Main:notConverged', ...
        'k-means did not converge within %d iterations.',max_loop);
end
end
subfunction
% 1. initial mean vectors
function cluster_vector = init_cluster_vector(k,melon_data)
%INIT_CLUSTER_VECTOR Build k initial mean vectors from consecutive row chunks.
%   cluster_vector = init_cluster_vector(k,melon_data)
%
%   k            number of clusters, 1 <= k <= number of rows of melon_data
%   melon_data   m-by-n matrix, one sample per row
%
%   Returns a k-by-n matrix whose i-th row is the mean of the i-th chunk of
%   rows. The rows are split into k consecutive, non-overlapping, non-empty
%   chunks of (almost) equal size.
[m,n] = size(melon_data);
if k < 1 || k > m
    error('init_cluster_vector:badK', ...
        'k must be between 1 and the number of samples (%d).',m);
end

% Chunk boundaries: chunk i covers rows edges(i)+1 .. edges(i+1). Because
% m/k >= 1 the rounded edges are strictly increasing, so no chunk is empty
% and no row belongs to two chunks.
edges = round(linspace(0,m,k+1));

cluster_vector = zeros(k,n);
for i = 1:k
    rows = edges(i)+1:edges(i+1);
    % explicit dimension: a single-row chunk must stay a row vector
    cluster_vector(i,:) = mean(melon_data(rows,:),1);
end
end
% 2.1 assign every sample to its nearest mean vector
function sample_vector = cluster(melon_data,cluster_vector)
%CLUSTER Distribute the samples over the clusters by nearest mean vector.
%   sample_vector = cluster(melon_data,cluster_vector)
%
%   melon_data       m-by-n matrix, one sample per row
%   cluster_vector   k-by-n matrix, one mean vector per row
%
%   Returns an m-by-n-by-k array. Page j holds, in its leading rows and in
%   input order, the samples whose nearest mean vector (Euclidean distance,
%   first one wins on ties) is row j of cluster_vector; the remaining rows
%   of the page are zero padding.
[m,n] = size(melon_data);
k = size(cluster_vector,1);
sample_vector = zeros(m,n,k);

% Number of samples already stored in each page. Tracking this explicitly
% (instead of searching the page for its first zero) keeps samples that
% contain a 0 feature from being overwritten.
counts = zeros(k,1);

for i = 1:m   % iterate over rows, not length(), which is the largest dimension
    delta = bsxfun(@minus,cluster_vector,melon_data(i,:));
    % squared distance has the same minimiser as the distance itself
    [~,min_ind] = min(sum(delta.^2,2));
    counts(min_ind) = counts(min_ind)+1;
    sample_vector(counts(min_ind),:,min_ind) = melon_data(i,:);
end
end
% 2.2 recompute the mean vectors
function cur_cluster_vector = update_cluter_vector(sample_vector)
%UPDATE_CLUTER_VECTOR Mean vector of every cluster page.
%   cur_cluster_vector = update_cluter_vector(sample_vector)
%
%   sample_vector   m-by-n-by-k array; page i holds the members of cluster i
%                   as rows, padded with all-zero rows
%
%   Returns a k-by-n matrix whose i-th row is the mean of the members of
%   page i. A row counts as a member when it has at least one non-zero
%   entry, so an all-zero sample cannot be told apart from padding. A page
%   without any member yields a NaN row (0/0).
[~,n,k] = size(sample_vector);
cur_cluster_vector = zeros(k,n);
for i = 1:k
    members = sample_vector(:,:,i);
    % Count member rows directly; this also works when the page is
    % completely full (no zero anywhere) or a member has a 0 feature.
    nums = nnz(any(members ~= 0,2));
    cur_cluster_vector(i,:) = sum(members,1)/nums;
end
end
function show(sample_vector,cur_cluster_vector)
%SHOW Plot the members and the mean vector of every cluster.
%   show(sample_vector,cur_cluster_vector)
%
%   sample_vector        m-by-n-by-k array of zero-padded cluster pages
%   cur_cluster_vector   k-by-n matrix, one mean vector per row
%
%   Closes the current figure and draws the first two features of every
%   cluster: members as '+', mean vectors as pentagrams, one colour per
%   cluster (red, green, blue, ... ; colours repeat after seven clusters).
%   All-zero padding rows are skipped so they do not appear as samples at
%   the origin.
close
palette = 'rgbcmyk';
k = size(sample_vector,3);

% members first, mean vectors afterwards, so the means are drawn on top
for i = 1:k
    colour = palette(mod(i-1,numel(palette))+1);
    members = sample_vector(:,:,i);
    members = members(any(members ~= 0,2),:);   % drop zero padding
    plot(members(:,1),members(:,2),['+' colour]);hold on
end
for i = 1:k
    colour = palette(mod(i-1,numel(palette))+1);
    plot(cur_cluster_vector(i,1),cur_cluster_vector(i,2),['p' colour]);
end
xlabel('density');ylabel('sugar rate');
end
function isStop = is_stop(cur_cluster_vector,cluster_vector,tol)
%IS_STOP Tell whether the mean vectors have stopped changing.
%   isStop = is_stop(cur_cluster_vector,cluster_vector)
%   isStop = is_stop(cur_cluster_vector,cluster_vector,tol)
%
%   cur_cluster_vector   mean vectors of the current iteration
%   cluster_vector       mean vectors of the previous iteration
%   tol                  optional absolute tolerance per entry (default 0,
%                        i.e. exact equality)
%
%   Returns 1 when both matrices have the same size and every pair of
%   entries is equal or differs by at most tol, otherwise 0. Matrices of
%   any size are supported; NaN entries never compare as equal.
if nargin < 3
    tol = 0;
end

isStop = 0;
if ~isequal(size(cur_cluster_vector),size(cluster_vector))
    return
end

cur = cur_cluster_vector(:);
prev = cluster_vector(:);
% the == term keeps identical infinite entries equal (Inf-Inf is NaN)
unchanged = (cur == prev) | (abs(cur-prev) <= tol);
if all(unchanged)
    isStop = 1;
end
end
循环次数为:3