k-means聚类算法是聚类算法中应用非常广泛的一种算法。它是属于划分法的一种,是一种基于距离的聚类方法,在聚类的开始需要指定一个K值,表示需要聚类的数目。
k-means聚类算法的思想非常容易理解:拿到待聚类的N个样本和需要聚类的数目K(K<N)(k怎么选择后边会介绍)以后,随机的在N个样本中选择K个点作为初始的聚类中心,然后计算剩余的点与每一个聚类中心点的距离(当然有许多种计算距离的方法,下边会介绍)并选择与自身距离最小的聚类中心点为聚类中心(也就是选择与自身距离最小的那个聚类中心点同一个类别),这样第一次迭代结束,接下来就是看这次的迭代能否达到你设定的目标(就是迭代终止条件,后边会介绍),若是达到了那么聚类结束,否则的话进行下一轮的迭代。下一轮的迭代其实就是重新计算聚类中心点(有可能不是你的样本数据点),然后计算其他点与新聚类中心的距离,重新选择类别。就这样依次迭代直到达到设定的终止条件。用算法的形式表示为:
设定聚类数目K;
在N个样本中随机的选择K个样本点作为初始的聚类中心点;
repeat
计算剩余的点和每一个聚类中心点的距离并选择与自身距离最小的聚类中心点为聚类中心;
重新计算每一个聚类的聚类中心;
until 迭代的终止条件
举个具体的例子来说明:
例:我有平面上的8(N=8)个样本点:p0(2,3)、p1(2,5)、p2(3,5)、p3(4,6)、p4(5,4)、p5(3,1)、p6(0,6)、p7(7,7),我想聚成3个类别。
按照上边的算法讲的步骤:
我设定聚类的数目K=3;
之后我在这8个样本中随机的选择3个点为初始的聚类中心点,比方说我选择了p1、p4、p5为初始的聚类中心点。
repeat
剩余的点有p0、p2、p3、p6、p7,分别计算他们与聚类中心点p1、p4、p5距离,这里我们选择距离的计算方法为欧氏距离。通过计算(拿p0为例)可得
(p0,p1)=2 (p0,p4)=sqrt(10) (p0,p5)=sqrt(5),由距离我们可以得到p0和p1的距离最小,所以将p0放到p1所在的类中。 用同样的方法我们可以
将p2、p3、p6、p7分别放到p1、p1、p1、p4。这样我们得到第一轮的三个类簇c1(p0,p1,p2,p3,p6),c2(p4,p7),c3(p5)。接下来我们重新计算聚
类中心点,用的方法就是取类簇中所有样本的均值,以类簇c2为例计算可得c2的新的聚类中心为((5+7)/2,(4+7)/2)=(6,5.5),同样的方法可以得到其 他
两个类簇的新的聚类中心c1=(2.2,5),c3=(3,1)。
注意(只是说第二轮,以后的每一轮一次类推):若是下一轮有迭代,剩余点为p0、p1、p2、p3p4、p6、 p7, 因为p5依然是聚类中心点。
util 聚类中心不变或者聚类中心的变化量小于某一个阈值或者达到迭代次数。
k-means聚类算法中有几个关键点需要注意一下:
K值的选择:
在实际的应用中k值的选择一般是靠经验来选择的,多试几次,选择其中对你的所要解决的问题最好的聚类数目。但是在学术上或者其他一些博客中也有一些对K值选择
的算 法,不过本人没有太去研究,不好做出评论。在这篇博客中点击打开链接,有对K值的选择和对初始聚类中心的选择的介绍,有兴趣的可以去看一下。
初始聚类中心点的选择:
工程上对于初始聚类中心点的选择,一般使用随机选择然后去迭代,都能够取的不错的效果。当然,也可以去参照K值选择中的那篇博客去求解。
与聚类中心距离的计算:
在工程中我使用的比较多的是欧氏距离和余弦夹角,这个想必不用介绍了,大家都应该清楚的。
迭代条件:
迭代条件关系到你的聚类的效果的问题,一般迭代条件会是三种情况:第一种是你设定迭代的次数(50、100或者更多。。。。自己设定),当迭代次数达到你所设定
次数,聚类自动终止;第二种是聚类中心不变,也就是上一轮迭代和下一轮迭代的所有的聚类中心都不变化了,这种情况下一般难以收敛,通常会和迭代次数一起使用
;第三种情况就是设定一个阈值,当所有的聚类中心点的变化范围都不大于你设置的阈值的时候,聚类结束,也可以与迭代次数一块使用。
好了,上边讲的这么多都是从我的实际应用中总结的,若是其他人有更好的见解,请不吝赐教。下边上代码(Java版,我可能会在写一份Python的供大家使用):
主类(Cluster_Kmeans.java):
package com.zc.test;
import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.OutputStreamWriter;
import java.util.ArrayList;
import java.util.LinkedList;
import java.util.List;
import com.zc.source.kmeans;
import com.zc.source.kmeans_data;
public class Cluster_Kmeans {
public static void main(String[] args) {
ArrayList<double[]> dimension = new ArrayList<double[]>();//存放样本词的向量值
List<String> dataSet = new LinkedList<String>();//存放样本中的词
File Input = new File("E:\\test_csdn\\test.txt");// 输入数据
File Output = new File("E:\\test_csdn\\rel.txt");// 输出结果
if (Output.exists()) {
Output.delete();
}
FileInputStream fis = null;
InputStreamReader isr = null;
BufferedReader br = null;
FileOutputStream fos = null;
OutputStreamWriter osw = null;
BufferedWriter bw = null;
try {
fis = new FileInputStream(Input);
isr = new InputStreamReader(fis, "utf-8");
br = new BufferedReader(isr);
fos = new FileOutputStream(Output);
osw = new OutputStreamWriter(fos, "utf-8");
bw = new BufferedWriter(osw);
String line = br.readLine();
String s[] = null;
while (line != null) {
double b[] = new double[200];
s = line.split(" ");//样本中的数据包词语和对应的向量值,中间使用空格隔开的,所以取出来的时候要用空格划分
for (int i = 1; i < s.length; i++) {
b[i - 1] = Double.parseDouble(s[i]);
}
dimension.add(b);//放入某个词的向量值
dataSet.add(s[0]);//放入词语
line = br.readLine();
}
System.out.println("数据加载完毕---------------");
double[][] ff = new double[dataSet.size()][200];
for (int i = 0; i < dimension.size(); i++) {
ff[i] = dimension.get(i);//将词语和对应的向量值对应起来
}
// 初始化数据结构
kmeans_data data = new kmeans_data(ff, dataSet.size(), 200);
// 调用doKmeans方法进行聚类,参数列表:聚类数目,数据集,迭代次数,聚类中心变化阈值
kmeans.doKmeans(2, data, 4000,0.0);
// 输出聚类结果
for (int i = 0; i < dataSet.size(); i++) {
bw.write(dataSet.get(i));
bw.write(" ");
bw.write(String.valueOf(data.labels[i]));
bw.write("\r\n");
}
} catch (Exception ex) {
ex.printStackTrace();
} finally {
try {
bw.flush();
osw.flush();
fos.flush();
fos.close();
br.close();
isr.close();
fis.close();
} catch (IOException e) {
e.printStackTrace();
}
}
}
}
数据类(kmeans_data.java):
package com.zc.source;
public class kmeans_data {
public double[][] data;//存放词语和向量
public double[] dis;//每个样本和聚类中心的距离
public int length;//样本大小N
public int dim;//向量的维度
public int[] labels;//样本所属的类簇的标签
public double[][] centers;//存放聚类中心
public int[] centerCounts;//存放某个类簇中含有的样本的个数
public kmeans_data(double[][] data, int length, int dim) {
this.data = data;
this.length = length;
this.dim = dim;
}
}
k-means的处理类(kmeans.java):
package com.zc.source;
import java.util.Arrays;
import java.util.Collections;
import java.util.LinkedList;
import java.util.List;
import java.util.Random;
public class kmeans {
/**
* double[][] 元素全置0
*
* @param matrix
* double[][]
* @param highDim
* int
* @param lowDim
* int <br/>
* double[highDim][lowDim]
*/
private static void setDouble2Zero(double[][] matrix, int highDim, int lowDim) {
for (int i = 0; i < highDim; i++) {
for (int j = 0; j < lowDim; j++) {
matrix[i][j] = 0;
}
}
}
/**
* 拷贝源二维矩阵元素到目标二维矩阵。 foreach (dests[highDim][lowDim] =
* sources[highDim][lowDim]);
*
* @param dests
* double[][]
* @param sources
* double[][]
* @param highDim
* int
* @param lowDim
* int
*/
private static void copyCenters(double[][] dests, double[][] sources, int highDim, int lowDim) {
for (int i = 0; i < highDim; i++) {
for (int j = 0; j < lowDim; j++) {
dests[i][j] = sources[i][j];
}
}
}
/**
* 更新聚类中心坐标
*
* @param k
* int 分类个数
* @param data
* kmeans_data
*/
private static void updateCenters(int k, kmeans_data data) {
double[][] centers = data.centers;
setDouble2Zero(centers, k, data.dim);
int[] labels = data.labels;
int[] centerCounts = data.centerCounts;
for (int i = 0; i < data.dim; i++) {
for (int j = 0; j < data.length; j++) {
centers[labels[j]][i] += data.data[j][i];
}
}
for (int i = 0; i < k; i++) {
for (int j = 0; j < data.dim; j++) {
centers[i][j] = centers[i][j] / centerCounts[i];
}
}
}
/**
* 计算两点余弦值
*
* @param pa
* double[]
* @param pb
* double[]
* @param dim
* int 维数
* @return double 距离
*/
public static double dist(double[] pa, double[] pb, int dim) {
double mpa = 0;// pa的莫
double mpb = 0;// pb的莫
double proab = 0;// pa和pb的向量积
for (int i = 0; i < dim; i++) {
proab = proab + pa[i] * pb[i];
mpa = mpa + pa[i] * pa[i];
mpb = mpb + pb[i] * pb[i];
}
double temp = 0;
temp = Math.sqrt(mpa) * Math.sqrt(mpb);
double result = proab / temp;
return result;
}
/**
* 计算两次聚类中心的欧式距离
*
* @param pa
* double[]
* @param pb
* double[]
* @param dim
* int 维数
* @return double 距离
*/
public static double distcen(double[] pa, double[] pb, int dim) {
double rv = 0;
for (int i = 0; i < dim; i++) {
double temp = pa[i] - pb[i];
temp = temp * temp;
rv += temp;
}
return Math.sqrt(rv);
}
/**
* 做Kmeans运算
*
* @param k
* int 聚类个数
* @param data
* kmeans_data kmeans数据类
* @param param
* kmeans_param kmeans参数类
* @return kmeans_result kmeans运行信息类
*/
public static void doKmeans(int k, kmeans_data data, int maxAttempts,double criteria) {
// 预处理
double[][] centers = new double[k][data.dim]; // 聚类中心点集
data.centers = centers;
int[] centerCounts = new int[k]; // 各聚类的包含点个数
data.centerCounts = centerCounts;
Arrays.fill(centerCounts, 0);
int[] labels = new int[data.length]; // 各个点所属聚类标号
data.labels = labels;
double[] dis = new double[data.length]; // 各个点于聚类中心的距离
data.dis = dis;
double[][] oldCenters = new double[k][data.dim]; // 临时缓存旧的聚类中心坐标
// 初始化聚类中心(随机选择data内的k个不重复点)
Random rn = new Random();
List<Integer> seeds = new LinkedList<Integer>();
while (seeds.size() < k) {
int randomInt = rn.nextInt(data.length);
if (!seeds.contains(randomInt)) {
seeds.add(randomInt);
}
}
Collections.sort(seeds);
for (int i = 0; i < k; i++) {
int m = seeds.remove(0);
for (int j = 0; j < data.dim; j++) {
centers[i][j] = data.data[m][j];
}
}
// 第一轮迭代
for (int i = 0; i < data.length; i++) {
double maxDist = dist(data.data[i], centers[0], data.dim);
int label = 0;
for (int j = 1; j < k; j++) {
double tempDist = dist(data.data[i], centers[j], data.dim);
if (tempDist > maxDist) {
maxDist = tempDist;
label = j;
}
}
dis[i] = maxDist;
labels[i] = label;
centerCounts[label]++;
}
updateCenters(k, data);//更新聚类中心
copyCenters(oldCenters, centers, k, data.dim);//赋值聚类中心
// 迭代预处理
int attempts = 1;
boolean[] flags = new boolean[k]; // 标记哪些中心被修改过
int it = 2;
// 迭代
iterate: while (attempts < maxAttempts) { // 迭代次数不超过最大值,最大中心改变量不超过阈值ֵ
for (int i = 0; i < k; i++) { // 初始化中心点“是否被修改过”标记
flags[i] = false;
}
for (int i = 0; i < data.length; i++) { // 遍历data内所有点
double maxDist = dist(data.data[i], centers[0], data.dim);
int label = 0;
for (int j = 1; j < k; j++) {
double tempDist = dist(data.data[i], centers[j], data.dim);
if (tempDist > maxDist) {
maxDist = tempDist;
label = j;
}
}
if (label != labels[i]) { // 如果当前点被聚类到新的类别则做更新
int oldLabel = labels[i];
labels[i] = label;
centerCounts[oldLabel]--;
centerCounts[label]++;
flags[oldLabel] = true;
flags[label] = true;
}
dis[i] = maxDist;
}
updateCenters(k, data);
attempts++;
// 得到被修改过的中心点最大修改量ֵ
double maxDist = 0;
for (int i = 0; i < k; i++) {
if (flags[i]) {
double tempDist = distcen(centers[i], oldCenters[i], data.dim);
if (maxDist < tempDist) {
maxDist = tempDist;
}
for (int j = 0; j < data.dim; j++) { // 更新oldCenter
oldCenters[i][j] = centers[i][j];
}
}
}
System.out.println("迭代第" + it + "次");
it++;
if (maxDist == criteria) {//查看被修改过的中心点最大修改量是否超过阈值
break iterate;
}
}
}
}
以上为全部的代码,我的数据格式为:word v1 v2 v3 ........具体的例子就是:中国 0.44 0.56 0.78 0.09 ...中间是用空格隔开的,因为在我的程序中我用的
样本是词语加上200维的向量,所以在我的主程序有些double数组中我直接写的是200,其他人若是用的时候需要根据具体情况去修改这个值。结果的格
式为:word label 例如:中国 1,说明中国这个词属于类簇1。