k-means是最常用的聚类算法之一。关于k-means算法的介绍随处可见，而且比较容易理解。Mahout中也实现了k-means算法。下面贴出的是我自己写的实现，目的是帮助大家更清楚地认识、更快地掌握k-means算法。
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import java.util.Set;
/**
 * A small, self-contained k-means (Lloyd's algorithm) demo over the fixed data set {@link #points}.
 *
 * <ol>
 *   <li>Randomly choose k distinct data points as the initial centroids.</li>
 *   <li>Compute the distance from every point to every centroid.</li>
 *   <li>Put each point into the cluster of its nearest centroid.</li>
 *   <li>Recompute each centroid as the mean of its cluster.</li>
 *   <li>Repeat steps 2-4 until the centroids stop changing (or an iteration cap is hit).</li>
 * </ol>
 *
 * @author aturbo
 */
public class MyKmeans {
    /** Sample 2-D data set clustered by {@link #main}. */
    public static final double[][] points = { { 1, 1 }, { 2, 1 }, { 1, 2 }, { 2, 2 }, { 3, 3 }, { 8, 8 }, { 9, 8 },
            { 8, 9 }, { 9, 9 } };

    /** Safety cap on assign/update rounds in {@link #main}; normally convergence happens far earlier. */
    private static final int MAX_ITERATIONS = 100;

    /**
     * Randomly chooses k distinct data points from {@link #points} as the initial centroids and
     * prints them.
     *
     * @param k number of centroids; must satisfy {@code 1 <= k <= points.length}
     * @return k centroids, each a copy of a distinct row of {@link #points}
     * @throws IllegalArgumentException if k is out of range
     */
    private static double[][] chooseinitK(int k) {
        if (k < 1 || k > points.length) {
            throw new IllegalArgumentException(
                    "k must be between 1 and " + points.length + ", got " + k);
        }
        int[] indices = new int[points.length];
        for (int i = 0; i < indices.length; i++) {
            indices[i] = i;
        }
        // Partial Fisher-Yates shuffle: after i swaps, indices[0..i) is a uniform sample
        // without replacement, so no rejection loop (and no Random per attempt) is needed.
        Random random = new Random();
        double[][] cluster = new double[k][];
        for (int i = 0; i < k; i++) {
            int j = i + random.nextInt(indices.length - i);
            int tmp = indices[i];
            indices[i] = indices[j];
            indices[j] = tmp;
            // Copy so a caller mutating a centroid cannot corrupt the shared data set.
            cluster[i] = points[indices[i]].clone();
        }
        for (double[] center : cluster) {
            System.out.println("随机选择的k个节点:" + join(center));
        }
        return cluster;
    }

    /**
     * Euclidean distance between two points.
     *
     * @param center (centroid) point; its length determines how many coordinates are compared
     * @param otherpoint the other point; must have at least {@code center.length} coordinates
     * @return the Euclidean distance
     */
    private static double eurDistance(double[] center, double[] otherpoint) {
        double sumOfSquares = 0.0;
        for (int i = 0; i < center.length; i++) {
            double diff = center[i] - otherpoint[i];
            sumOfSquares += diff * diff;
        }
        return Math.sqrt(sumOfSquares);
    }

    /**
     * Objective value: the sum, over all clusters, of each member's Euclidean distance to its
     * cluster's centroid.
     *
     * @param center centroids; {@code center[i]} belongs to {@code cluster[i]}
     * @param cluster the current partition of the points
     * @return total point-to-centroid distance
     */
    private static double cost(double[][] center, List<double[]>[] cluster) {
        double cost = 0.0;
        for (int i = 0; i < cluster.length; i++) {
            for (double[] point : cluster[i]) {
                // Delegates to eurDistance so every coordinate of the point is used,
                // independent of how many centroids there are.
                cost += eurDistance(center[i], point);
            }
        }
        return cost;
    }

    /**
     * Assignment step: puts every point into the cluster of its nearest centroid.
     *
     * @param points all points to cluster
     * @param centers current centroids
     * @param k number of clusters to create (expected to equal {@code centers.length})
     * @return k lists; list i holds the points whose nearest centroid is {@code centers[i]}
     */
    private static List<double[]>[] returnCluster(double[][] points, double[][] centers, int k) {
        @SuppressWarnings({ "unchecked", "rawtypes" }) // generic arrays cannot be created directly
        List<double[]>[] cluster = new List[k];
        for (int i = 0; i < k; i++) {
            cluster[i] = new ArrayList<>();
        }
        for (double[] point : points) {
            int nearest = 0;
            double minDistance = Double.MAX_VALUE;
            for (int c = 0; c < centers.length; c++) {
                double distance = eurDistance(centers[c], point);
                // Strict '<' keeps the lowest-index centroid on ties.
                if (distance < minDistance) {
                    nearest = c;
                    minDistance = distance;
                }
            }
            cluster[nearest].add(point);
        }
        return cluster;
    }

    /**
     * Update step: computes each cluster's centroid as the coordinate-wise mean of its members.
     * Works for any dimension. An empty cluster has no defined mean, so its centroid is filled
     * with {@link Double#NaN}.
     *
     * @param cluster the current partition of the points
     * @return one centroid per cluster
     */
    private static double[][] countCenter(List<double[]>[] cluster) {
        return countCenter(cluster, null);
    }

    /**
     * Update step that keeps the previous centroid of any cluster that ended up empty, which
     * keeps every centroid finite across iterations.
     *
     * @param cluster the current partition of the points
     * @param previous centroids from the previous round (at least {@code cluster.length} rows),
     *     or {@code null} to fill empty clusters with NaN
     * @return one centroid per cluster
     */
    private static double[][] countCenter(List<double[]>[] cluster, double[][] previous) {
        int dimension = dimensionOf(cluster);
        double[][] centers = new double[cluster.length][];
        for (int i = 0; i < cluster.length; i++) {
            List<double[]> members = cluster[i];
            if (members.isEmpty()) {
                if (previous != null) {
                    centers[i] = previous[i].clone();
                } else {
                    centers[i] = new double[dimension];
                    Arrays.fill(centers[i], Double.NaN);
                }
                continue;
            }
            // Fresh accumulator per cluster so sums never leak from one cluster to the next.
            double[] sum = new double[dimension];
            for (double[] point : members) {
                for (int d = 0; d < dimension; d++) {
                    sum[d] += point[d];
                }
            }
            for (int d = 0; d < dimension; d++) {
                sum[d] /= members.size();
            }
            centers[i] = sum;
        }
        return centers;
    }

    /** Returns the coordinate count of the first point found in any cluster, or 0 if all are empty. */
    private static int dimensionOf(List<double[]>[] cluster) {
        for (List<double[]> members : cluster) {
            if (!members.isEmpty()) {
                return members.get(0).length;
            }
        }
        return 0;
    }

    /** Formats a point's coordinates separated by tabs, e.g. {@code "1.0\t2.0"}. */
    private static String join(double[] point) {
        StringBuilder sb = new StringBuilder();
        for (int i = 0; i < point.length; i++) {
            if (i > 0) {
                sb.append('\t');
            }
            sb.append(point[i]);
        }
        return sb.toString();
    }

    /**
     * Runs k-means with k = 2 on {@link #points} and prints the resulting clusters.
     *
     * <p>Iteration stops when the update step no longer moves any centroid (the assignment is
     * then stable). The sum of unsquared distances reported by {@link #cost} is not what the
     * mean-update step minimizes, so it is not used as the stopping test.
     */
    public static void main(String[] args) {
        int k = 2;
        double[][] centers = chooseinitK(k);
        List<double[]>[] cluster = returnCluster(points, centers, k);
        for (int iteration = 0; iteration < MAX_ITERATIONS; iteration++) {
            double[][] updated = countCenter(cluster, centers);
            boolean converged = Arrays.deepEquals(updated, centers);
            centers = updated;
            if (converged) {
                break;
            }
            // Keep `cluster` consistent with `centers` at the end of every round.
            cluster = returnCluster(points, centers, k);
        }
        for (int i = 0; i < k; i++) {
            System.out.println("第" + i + "类中心点:" + join(centers[i]) + "\t点数:" + cluster[i].size());
        }
        System.out.println("目标函数(距离和):" + cost(centers, cluster));
    }
}