描述
算法接受参数 k ;然后将事先输入的n个数据对象划分为 k个聚类以便使得所获得的聚类满足:同一聚类中的对象相似度较高;而不同聚类中的对象相似度较小。聚类相似度是利用各聚类中对象的均值所获得一个“中心对象”(引力中心)来进行计算的。
K-means算法是最为经典的基于划分的聚类方法,是十大经典数据挖掘算法之一。K-means算法的基本思想是:以空间中k个点为中心进行聚类,对最靠近他们的对象归类。通过迭代的方法,逐次更新各聚类中心的值,直至得到最好的聚类结果。
该算法的最大优势在于简洁和快速。算法的关键在于初始中心的选择和距离公式。
聚类
聚类分析是一种静态数据分析方法,常被用于机器学习,模式识别,数据挖掘等领域。通常认为,聚类是一种无监督式的机器学习方法,它的过程是这样的:在未知样本类别的情况下,通过计算样本彼此间的距离(欧式距离,马式距离,汉明距离,余弦距离等)来估计样本所属类别。从结构性来划分,聚类方法分为自上而下和自下而上两种方法,前者的算法是先把所有样本视为一类,然后不断从这个大类中分离出小类,直到不能再分为止;后者则相反,首先所有样本自成一类,然后不断两两合并,直到最终形成几个大类。
聚类算法
常用的聚类方法主要有以下四种:
- Connectivity based clustering (如hierarchical clustering 层次聚类法)
- Centroid-based clustering (如kmeans)
- Distribution-based clustering
- Density-based clustering
优缺点
Kmeans聚类是一种自下而上的聚类方法,它的优点是简单、速度快;缺点是聚类结果与初始中心的选择有关系,且必须提供聚类的数目。Kmeans的第二个缺点是致命的,因为在有些时候,我们不知道样本集将要聚成多少个类别,这种时候kmeans是不适合的,推荐使用hierarchical 或meanshift来聚类。第一个缺点可以通过多次聚类取最佳结果来解决。
算法流程
首先从n个数据对象任意选择 k 个对象作为初始聚类中心;
而对于所剩下其它对象,则根据它们与这些聚类中心的相似度(距离),分别将它们分配给与其最相似的(聚类中心所代表的)聚类;
然后再计算每个所获新聚类的聚类中心(该聚类中所有对象的均值);
不断重复这一过程直到标准测度函数开始收敛为止。
一般都采用均方差作为标准测度函数.
k个聚类具有以下特点:各聚类本身尽可能的紧凑,而各聚类之间尽可能的分开。
输入:k, data[n];
- 选择k个初始中心点,例如c[0]=data[0],…c[k-1]=data[k-1];
- 对于data[0]….data[n], 分别与c[0]…c[k-1]比较,假定与c[i]差值最少,就标记为i;
- 对于所有标记为i点,重新计算c[i]={ 所有标记为i的data[j]之和}/标记为i的个数;
- 重复(2)(3),直到所有c[i]值的变化小于给定阈值。
源码示例
python
自编
# coding=utf-8
'''
* k-Means算法,聚类算法;
* 实现步骤:1. 首先是随机获取总体中的K个元素作为总体的K个中心;
* 2. 接下来对总体中的元素进行分类,每个元素都去判断自己到K个中心的距离,并归类到最近距离中心去;
* 3. 计算每个聚类的平均值,并作为新的中心点
* 4. 重复2,3步骤,直到这k个中线点不再变化(收敛了),或执行了足够多的迭代
'''
from numpy import linalg as ll
import numpy as np
import random
from sklearn import cluster as sc
# train data
srcdata = [ [ 5.1, 3.5, 1.4, 0.2],
[ 4.9, 3.0, 1.4, 0.2 ],[ 4.7, 3.2, 1.3, 0.2 ],
[ 4.6, 3.1, 1.5, 0.2 ],[ 5.0, 3.6, 1.4, 0.2 ],
[ 7.0, 3.2, 4.7, 1.4 ],[ 6.4, 3.2, 4.5, 1.5 ],
[ 6.9, 3.1, 4.9, 1.5 ],[ 5.5, 2.3, 4.0, 1.3 ],
[ 6.5, 2.8, 4.6, 1.5 ],[ 5.7, 2.8, 4.5, 1.3 ],
[ 6.5, 3.0, 5.8, 2.2 ],[ 7.6, 3.0, 6.6, 2.1 ],
[ 4.9, 2.5, 4.5, 1.7 ],[ 7.3, 2.9, 6.3, 1.8 ],
[ 6.7, 2.5, 5.8, 1.8 ],[ 6.9, 3.1, 5.1, 2.3 ] ]
#print srcdata
srclen = len(srcdata)
len0 = len(srcdata[0])
# set k = 5
k = 5
# step 1
# index = range(k) -- use first k
# use random to get index.
index = random.sample( range( srclen ) , k)
#print index
c = np.arange(k*len0,dtype=float).reshape(k,len0)
for i in range(k):
for j in range(len0):
c[i][j] = srcdata[index[i]][j]
print c
# set stop threshold
delta = 0.001
dis = np.arange(k,dtype=float)
calindex = np.arange(srclen)
norm = 10.0
# step 4
while norm>delta:
# step 2
for i in range( srclen ):
for m in range(k):
dis[m] = 0.0
for j in range(len0):
for m in range(k):
dis[m] += (srcdata[i][j] - c[m][j])*(srcdata[i][j] - c[m][j])
calindex[i] = np.where(dis == np.min(dis) )[0][0]
# step 3
newc = np.zeros(k*len0).reshape(k,len0)
for xx in range(srclen):
for j in range(len0):
newc[calindex[xx]][j] += srcdata[xx][j]
for i in range(k):
size = len(np.where(calindex == i)[0])
#print size
for j in range(len0):
newc[i][j]/=size
norm = ll.norm(c-newc)
#print norm
c = newc.copy()
print c
print calindex
## sklearn -> k-means++
cc = sc.k_means(srcdata,k)
print cc
Java
import java.io.BufferedReader;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
public class Kmeans {
/**
* @param args
* @throws IOException
*/
public static List<ArrayList<ArrayList<Double>>>
initHelpCenterList(List<ArrayList<ArrayList<Double>>> helpCenterList,int k){
for(int i=0;i<k;i++){
helpCenterList.add(new ArrayList<ArrayList<Double>>());
}
return helpCenterList;
}
/**
* @param args
* @throws IOException
*/
public static void main(String[] args) throws IOException{
List<ArrayList<Double>> centers = new ArrayList<ArrayList<Double>>();
List<ArrayList<Double>> newCenters = new ArrayList<ArrayList<Double>>();
List<ArrayList<ArrayList<Double>>> helpCenterList = new ArrayList<ArrayList<ArrayList<Double>>>();
//读入原始数据
BufferedReader br=new BufferedReader(new InputStreamReader(new FileInputStream("wine.txt")));
String data = null;
List<ArrayList<Double>> dataList = new ArrayList<ArrayList<Double>>();
while((data=br.readLine())!=null){
//System.out.println(data);
String []fields = data.split(",");
List<Double> tmpList = new ArrayList<Double>();
for(int i=0; i<fields.length;i++)
tmpList.add(Double.parseDouble(fields[i]));
dataList.add((ArrayList<Double>) tmpList);
}
br.close();
//随机确定K个初始聚类中心
Random rd = new Random();
int k=3;
int [] initIndex={
59,71,48};
int [] helpIndex = {
0,59,130};
int [] givenIndex = {
0,1,2};
System.out.println("random centers' index");
for(int i=0;i<k;i++){
int index = rd.nextInt(initIndex[i]) + helpIndex[i];
//int index = givenIndex[i];
System.out.println("index "+index);
centers.add(dataList.get(index));
helpCenterList.add(new ArrayList<ArrayList<Double>>());
}
/*
//注释掉的这部分目的是,取测试数据集最后稳定的三个类簇的聚类中心作为初始聚类中心
centers = new ArrayList<ArrayList<Double>>();
for(int i=0;i<59;i++)
helpCenterList.get(0).add(dataList.get(i));
for(int i=59;i<130;i++)
helpCenterList.get(1).add(dataList.get(i));
for(int i=130;i<178;i++)
helpCenterList.get(2).add(dataList.get(i));
for(int i=0;i<k;i++){
ArrayList<Double> tmp = new ArrayList<Double>();
for(int j=0;j<dataList.get(0).size();j++){
double sum=0;
for(int t=0;t<helpCenterList.get(i).size();t++)
sum+=helpCenterList.get(i).get(t).get(j);
tmp.add(sum/helpCenterList.get(i).size());
}
centers.add(tmp);
}
*/
//输出k个初始中心
System.out.println("original centers:");
for(int i=0;i<k;i++)
System.out.println(centers.get(i));
while(true)
{
//进行若干次迭代,直到聚类中心稳定
for(int i=0;i<dataList.size();i++){
//标注每一条记录所属于的中心
double minDistance=99999999;
int centerIndex=-1;
for(int j=0;j<k;j++){
//离0~k之间哪个中心最近
double currentDistance=0;
for(int t=1;t<centers.get(0).size();t++){
//计算两点之间的欧式距离
currentDistance += ((centers.get(j).get(t)-dataList.get(i).get(t))/(centers.get(j).get(t)+dataList.get(i).get(t))) * ((centers.get(j).get(t)-dataList.get(i).get(t))/(centers.get(j).get(t)+dataList.get(i).get(t)));
}
if(minDistance>currentDistance){
minDistance=currentDistance;
centerIndex=j;
}
}
helpCenterList.get(centerIndex).add(dataList.get(i));
}
// System.out.println(helpCenterList);
//计算新的k个聚类中心
for(int i=0;i<k;i++){
ArrayList<Double> tmp = new ArrayList<Double>();
for(int j=0;j<centers.get(0).size();j++){
double sum=0;
for(int t=0;t<helpCenterList.get(i).size();t++)
sum+=helpCenterList.get(i).get(t).get(j);
tmp.add(sum/helpCenterList.get(i).size());
}
newCenters.add(tmp);
}
System.out.println("\nnew clusters' centers:\n");
for(int i=0;i<k;i++)
System.out.println(newCenters.get(i));
//计算新旧中心之间的距离,当距离小于阈值时,聚类算法结束
double distance=0;
for(int i=0;i<k;i++){
for(int j=1;j<centers.get(0).size();j++){
//计算两点之间的欧式距离
distance += ((centers.get(i).get(j)-newCenters.get(i).get(j))/(centers.get(i).get(j)+newCenters.get(i).get(j))) * ((centers.get(i).get(j)-newCenters.get(i).get(j))/(centers.get(i).get(j)+newCenters.get(i).get(j)));
}
//System.out.println(i+" "+distance);
}
System.out.println("\ndistance: "+distance+"\n\n");
if(distance==0)//小于阈值时,结束循环
break;
else//否则,新的中心来代替旧的中心,进行下一轮迭代
{
centers = new ArrayList<ArrayList<Double>>(newCenters);
//System.out.println(newCenters);
newCenters = new ArrayList<ArrayList<Double>>();
helpCenterList = new ArrayList<ArrayList<ArrayList<Double>>>();
helpCenterList=initHelpCenterList(helpCenterList,k);
}
}
//输出最后聚类结果
for(int i=0;i<k;i++){
System.out.println("\n\nCluster: "+(i+1)+" size: "+helpCenterList.get(i).size()+" :\n\n");
for(int j=0;j<helpCenterList.get(i).size();j++)
{
System.out.println(helpCenterList.get(i).get(j));
}
}
}
}
test
测试数据集来源:wine数据集from UCI
以下是某次的运行结果,可以看出聚类结果与参考结果比较吻合
random centers' index
index 4
index 102
index 166
original centers:
[1.0, 13.744745762711865, 2.0106779661016954, 2.455593220338984, 17.037288135593222, 106.33898305084746, 2.8401694915254234, 2.982372881355932, 0.29, 1.8993220338983055, 5.528305084745763, 1.0620338983050848, 3.1577966101694916, 1115.7118644067796]
[2.0, 12.278732394366198, 1.932676056338028, 2.244788732394365, 20.238028169014086, 94.54929577464789, 2.2588732394366198, 2.080845070422536, 0.363661971830986, 1.6302816901408452, 3.08661971830986, 1.0562816901408452, 2.785352112676055, 519.5070422535211]
[3.0, 13.153749999999997, 3.3337500000000007, 2.4370833333333333, 21.416666666666668, 99.3125, 1.6787500000000002, 0.7814583333333331, 0.44749999999999995, 1.1535416666666667, 7.396249979166668, 0.6827083333333334, 1.6835416666666658, 629.8958333333334]
new clusters' centers:
[1.040983606557377, 13.695245901639344, 1.9844262295081967, 2.446475409836067, 17.199999999999996, 106.13934426229508, 2.8595901639344263, 3.0022950819672127, 0.28737704918032797, 1.900901639344262, 5.5009836065573765, 1.062131147540984, 3.1613934426229497, 1101.0]
[2.0, 12.272352941176464, 1.9286764705882358, 2.2503676470588228, 20.209558823529413, 94.72058823529412, 2.229705882352941, 2.0449999999999995, 0.3635294117647058, 1.6292647058823533, 3.0122794117647054, 1.0586176470588233, 2.7795588235294093, 517.6102941176471]
[2.9693877551020407, 13.146530612244899,