K均值算法分析与实现
一、问题分析
题目要求对以下的十个点进行K均值聚类,
{x1(0,0),x2(3,8),x3(2,2),x4(1,1),x5(5,3),x6(4,8),x7(6,3),x8(5,4),x9(6,4),x10(7,5)}
首先,使用matlab绘出这十个点的散点图,如图所示。
二、方案实施
K均值聚类算法的原理为:
(1)任选K个模式特征矢量作为初始聚类中心:z1(1),z2(1),…zK(1)。括号内的序号表示迭代次数。
(2)将待分类的模式特征矢量集{x}中的模式逐个按最小距离原则分划给K类中的某一类。
如果Dj(k) =min{||x-zi(k)||},i=1,2,…,K,则判x∈Sj(k)
(3)计算重新分类后的各聚类中心zj(k+1),即求各聚类域中所包含样本的均值向量:
以均值向量作新的聚类中心,可得新的准则函数:
(4)如果zj(k+1)=zj(k)(j=1,2,…K),则结束;否则,k=k+1,转(2)
如图所示。
题目提供的数据集在本方案的操作下,一定迭代了四次。计算过程如下:
1、数据集初始化三个中心点分别为initCenter[0]={6.0,4.0},initCenter[1]={6.0,3.0},initCenter[2]={2.0,2.0},如图所示。
接着,根据最小距离原则,将数据集中的点归类到对应的距离最小的簇中去。如图所示。
2、根据公式计算新一轮迭代的簇的中心,分别(5.0,5.8)、(5.5,3.0)、(1.0,1.0),如图所示。
接着,根据最小距离原则,将数据集中的点归类到对应的距离最小的簇中去。如图所示。
3、根据公式计算新一轮迭代的簇的中心,分别(4.6666665,7.0)、(5.5,3.5)、(1.0,1.0),如图所示。
接着,根据最小距离原则,将数据集中的点归类到对应的距离最小的簇中去。如图所示。
4、根据公式计算新一轮迭代的簇的中心,分别(3.5,8.0)、(5.8,3.8)、(1.0,1.0),如图所示。
接着,根据最小距离原则,将数据集中的点归类到对应的距离最小的簇中去。如图所示。
5、最后,通过第五次迭代检查误差是否不再变化,经过检查,第五次迭代的结果与第四次一样,误差不再发生变化,因此,此方案聚类计算的结果如图所示。
1、Kmeans.java
package my;
import java.util.ArrayList;
import java.util.Random;
//K均值聚类算法
public class Kmeans {
private int k;//簇的个数
private int n;//簇的个数
//数据集合的长度,即数据集中有多少个点
private int dataSetNum;
private ArrayList<float[]> dataSet; //数据集链表
private ArrayList<ArrayList<float[]>> Cluster; // 簇
private ArrayList<float[]> Center;
private ArrayList<Float> SSE;//距离平方和
private Random random;
//构造函数,传入聚类的簇的个数
public Kmeans(int k) {
if (k <= 0) {
k = 1;
}
//如果传入的k小于等于0,设置为1
this.k = k;
}
//设置原始聚类数据集
public void setDataSet(ArrayList<float[]> dataSet) {
this.dataSet = dataSet;
}
//return聚类结果
public ArrayList<ArrayList<float[]>> getCluster() {
return Cluster;
}
//初始化
private void init() {
n = 0;
random = new Random();
dataSetNum = dataSet.size();
if (k > dataSetNum) {
k = dataSetNum;
}
Center = initCenters();
Cluster = initCluster();
SSE = new ArrayList<Float>();
}
//初始化聚类中心数据链表
private ArrayList<float[]> initCenters() {
ArrayList<float[]> center = new ArrayList<float[]>();
//中心点的个数和簇的个数一样
int[] randoms = new int[k];
boolean flag;
int temp = random.nextInt(dataSetNum);
randoms[0] = temp;
for (int i = 1; i < k; i++) {
flag = true;
while (flag) {
temp = random.nextInt(dataSetNum);
int j = 0;
while (j < i) {
if (temp == randoms[j]) {
break;
}
j++;
}
if (j == i) {
flag = false;
}
}
randoms[i] = temp;
}
for (int i = 0; i < k; i++) {
center.add(dataSet.get(randoms[i]));
}
return center;
}
//初始化簇,返回一个有k个簇的数据集
private ArrayList<ArrayList<float[]>> initCluster() {
ArrayList<ArrayList<float[]>> cluster = new ArrayList<ArrayList<float[]>>();
for (int i = 0; i < k; i++) {
cluster.add(new ArrayList<float[]>());
}
return cluster;
}
//计算数据点和中心点距离
private float distance(float[] element, float[] center) {
float distance = 0.0f;
float x = element[0] - center[0];
float y = element[1] - center[1];
float z = x * x + y * y;
distance = (float) Math.sqrt(z);
return distance;
}
//获取距离集合中最小距离的位置,返回最小距离在距离数组中的位置
private int minDistance(float[] distance) {
float minDistance = distance[0];
int minLocation = 0;
for (int i = 1; i < distance.length; i++) {
if (distance[i] < minDistance) {
minDistance = distance[i];
minLocation = i;
}
else if (distance[i] == minDistance)
{
if (random.nextInt(10) < 5) {
// 如果相等,随机返回一个位置
minLocation = i;
}
}
}
return minLocation;
}
//将当前元素放到最小距离中心相关的簇中
private void clusterSet() {
float[] distance = new float[k];
for (int i = 0; i < dataSetNum; i++) {
for (int j = 0; j < k; j++) {
distance[j] = distance(dataSet.get(i), Center.get(j));
// System.out.println("test2:"+"dataSet["+i+"],center["+j+"],distance="+distance[j]);
}
int minLocation = minDistance(distance);
// System.out.println("test3:"+"dataSet["+i+"],minLocation="+minLocation);
// System.out.println();
//将当前元素放到最小距离中心相关的簇中
Cluster.get(minLocation).add(dataSet.get(i));
}
}
//求两点误差平方的方法
private float errorSquare(float[] element, float[] center) {
float x = element[0] - center[0];
float y = element[1] - center[1];
float errSquare = x * x + y * y;
return errSquare;
}
//计算误差平方和准则函数方法
private void countRule() {
float jcF = 0;
for (int i = 0; i < Cluster.size(); i++) {
for (int j = 0; j < Cluster.get(i).size(); j++) {
jcF += errorSquare(Cluster.get(i).get(j), Center.get(i));
}
}
SSE.add(jcF);
}
//设置新的簇中心方法
private void setNewCenter() {
for (int i = 0; i < k; i++) {
int n = Cluster.get(i).size();
if (n != 0) {
float[] newCenter = { 0, 0 };
for (int j = 0; j < n; j++) {
newCenter[0] += Cluster.get(i).get(j)[0];
newCenter[1] += Cluster.get(i).get(j)[1];
}
// 设置平均值
newCenter[0] = newCenter[0] / n;
newCenter[1] = newCenter[1] / n;
Center.set(i, newCenter);
}
}
}
//打印数据集
public void printDataArray(ArrayList<float[]> dataArray,
String dataArrayName) {
for (int i = 0; i < dataArray.size(); i++) {
System.out.println("print:" + dataArrayName + "[" + i + "]={"
+ dataArray.get(i)[0] + "," + dataArray.get(i)[1] + "}");
}
System.out.println("————————————————————————————————");
}
private void kmeans() {
init();
printDataArray(dataSet,"initDataSet");
printDataArray(Center,"initCenter");
// 循环分组,直到误差不变为止
while (true) {
clusterSet();
for(int i=0;i<Cluster.size();i++)
{
printDataArray(Cluster.get(i),"cluster["+i+"]");
}
countRule();
System.out.println("count:"+"jc["+n+"]="+SSE.get(n));
System.out.println();
// 误差不变了,分组完成
if (n != 0) {
if (SSE.get(n) - SSE.get(n - 1) == 0) {
break;
}
}
setNewCenter();
printDataArray(Center,"newCenter");
n++;
Cluster.clear();
Cluster = initCluster();
}
System.out.println("note:the times of repeat:n="+n);//输出迭代次数
}
public void execute() {
long startTime = System.currentTimeMillis();
System.out.println("kmeans begins");
kmeans();
long endTime = System.currentTimeMillis();
System.out.println("kmeans running time=" + (endTime - startTime)
+ "ms");
System.out.println("kmeans ends");
System.out.println();
}
}
2、Test.java
package my;
import java.util.ArrayList;
import my.Kmeans;
public class Test {
public static void main(String[] args)
{
//初始化一个Kmean对象,将k置为3
Kmeans k=new Kmeans(3);
ArrayList<float[]> dataSet=new ArrayList<float[]>();
dataSet.add(new float[]{0,0});
dataSet.add(new float[]{3,8});
dataSet.add(new float[]{2,2});
dataSet.add(new float[]{1,1});
dataSet.add(new float[]{5,3});
dataSet.add(new float[]{4,8});
dataSet.add(new float[]{6,3});
dataSet.add(new float[]{5,4});
dataSet.add(new float[]{6,4});
dataSet.add(new float[]{7,5});
//设置原始数据集
k.setDataSet(dataSet);
//执行算法
k.execute();
//得到聚类结果
ArrayList<ArrayList<float[]>> cluster=k.getCluster();
//查看结果
for(int i=0;i<cluster.size();i++)
{
k.printDataArray(cluster.get(i), "cluster["+i+"]");
}
}
}