K均值算法分析与实现(附源码)

K均值算法分析与实现

一、问题分析

题目要求对以下的十个点进行K均值聚类,
{x1(0,0),x2(3,8),x3(2,2),x4(1,1),x5(5,3),x6(4,8),x7(6,3),x8(5,4),x9(6,4),x10(7,5)}
首先,使用matlab绘出这十个点的散点图,如图所示。
在这里插入图片描述

二、方案实施

K均值聚类算法的原理为:
(1)任选K个模式特征矢量作为初始聚类中心:z1(1),z2(1),…zK(1)。括号内的序号表示迭代次数。
(2)将待分类的模式特征矢量集{x}中的模式逐个按最小距离原则分划给K类中的某一类。
如果Dj(k) =min{||x-zi(k)||},i=1,2,…,K,则判x∈Sj(k)
(3)计算重新分类后的各聚类中心zj(k+1),即求各聚类域中所包含样本的均值向量:
在这里插入图片描述
以均值向量作新的聚类中心,可得新的准则函数:
在这里插入图片描述
(4)如果zj(k+1)=zj(k)(j=1,2,…K),则结束;否则,k=k+1,转(2)
如图所示。
在这里插入图片描述
题目提供的数据集在本方案的操作下,一定迭代了四次。计算过程如下:
1、数据集初始化三个中心点分别为initCenter[0]={6.0,4.0},initCenter[1]={6.0,3.0},initCenter[2]={2.0,2.0},如图所示。
在这里插入图片描述
接着,根据最小距离原则,将数据集中的点归类到对应的距离最小的簇中去。如图所示。
在这里插入图片描述
2、根据公式计算新一轮迭代的簇的中心,分别(5.0,5.8)、(5.5,3.0)、(1.0,1.0),如图所示。
在这里插入图片描述接着,根据最小距离原则,将数据集中的点归类到对应的距离最小的簇中去。如图所示。
在这里插入图片描述
3、根据公式计算新一轮迭代的簇的中心,分别(4.6666665,7.0)、(5.5,3.5)、(1.0,1.0),如图所示。
在这里插入图片描述
接着,根据最小距离原则,将数据集中的点归类到对应的距离最小的簇中去。如图所示。
在这里插入图片描述
4、根据公式计算新一轮迭代的簇的中心,分别(3.5,8.0)、(5.8,3.8)、(1.0,1.0),如图所示。
在这里插入图片描述
接着,根据最小距离原则,将数据集中的点归类到对应的距离最小的簇中去。如图所示。
在这里插入图片描述
5、最后,通过第五次迭代检查误差是否不再变化,经过检查,第五次迭代的结果与第四次一样,误差不再发生变化,因此,此方案聚类计算的结果如图所示。
在这里插入图片描述

1Kmeans.java
package my;

import java.util.ArrayList;
import java.util.Random;
 
//K均值聚类算法
public class Kmeans {
	private int k;//簇的个数
	private int n;//簇的个数
	//数据集合的长度,即数据集中有多少个点
	private int dataSetNum;
	private ArrayList<float[]> dataSet;	//数据集链表
	private ArrayList<ArrayList<float[]>> Cluster; // 簇
	private ArrayList<float[]> Center;
	private ArrayList<Float> SSE;//距离平方和
	private Random random;
 
	
	//构造函数,传入聚类的簇的个数
	public Kmeans(int k) {
		if (k <= 0) {
			k = 1;
		}
		//如果传入的k小于等于0,设置为1
		this.k = k;
	}
	
	//设置原始聚类数据集
	public void setDataSet(ArrayList<float[]> dataSet) {
		this.dataSet = dataSet;
	}
 
    //return聚类结果
	public ArrayList<ArrayList<float[]>> getCluster() {
		return Cluster;
	}
 
	//初始化
	private void init() {
		n = 0;
		random = new Random();
		dataSetNum = dataSet.size();
		if (k > dataSetNum) {
			k = dataSetNum;
		}
		Center = initCenters();
		Cluster = initCluster();
		SSE = new ArrayList<Float>();
	}


	//初始化聚类中心数据链表
	private ArrayList<float[]> initCenters() {
		ArrayList<float[]> center = new ArrayList<float[]>();
		//中心点的个数和簇的个数一样
		int[] randoms = new int[k];
		boolean flag;
		int temp = random.nextInt(dataSetNum);
		randoms[0] = temp;
		for (int i = 1; i < k; i++) {
			flag = true;
			while (flag) {
				temp = random.nextInt(dataSetNum);
				int j = 0;
				while (j < i) {
					if (temp == randoms[j]) {
						break;
					}
					j++;
				}
				if (j == i) {
					flag = false;
				}
			}
			randoms[i] = temp;
		}
 
		for (int i = 0; i < k; i++) {
			center.add(dataSet.get(randoms[i]));
		}
		return center;
	}
 
	//初始化簇,返回一个有k个簇的数据集
	private ArrayList<ArrayList<float[]>> initCluster() {
		ArrayList<ArrayList<float[]>> cluster = new ArrayList<ArrayList<float[]>>();
		for (int i = 0; i < k; i++) {
			cluster.add(new ArrayList<float[]>());
		}
 
		return cluster;
	}
 
	//计算数据点和中心点距离
	private float distance(float[] element, float[] center) {
		float distance = 0.0f;
		float x = element[0] - center[0];
		float y = element[1] - center[1];
		float z = x * x + y * y;
		distance = (float) Math.sqrt(z);
 
		return distance;
	}
 
	//获取距离集合中最小距离的位置,返回最小距离在距离数组中的位置
	private int minDistance(float[] distance) {
		float minDistance = distance[0];
		int minLocation = 0;
		for (int i = 1; i < distance.length; i++) {
			if (distance[i] < minDistance) {
				minDistance = distance[i];
				minLocation = i;
			} 
			else if (distance[i] == minDistance) 
			{
				if (random.nextInt(10) < 5) {
					// 如果相等,随机返回一个位置
					minLocation = i;
				}
			}
		}
 
		return minLocation;
	}
 

	//将当前元素放到最小距离中心相关的簇中
	private void clusterSet() {
		float[] distance = new float[k];
		for (int i = 0; i < dataSetNum; i++) {
			for (int j = 0; j < k; j++) {
				distance[j] = distance(dataSet.get(i), Center.get(j));
				// System.out.println("test2:"+"dataSet["+i+"],center["+j+"],distance="+distance[j]);
 
			}
			int minLocation = minDistance(distance);
			// System.out.println("test3:"+"dataSet["+i+"],minLocation="+minLocation);
			// System.out.println();
			//将当前元素放到最小距离中心相关的簇中
			Cluster.get(minLocation).add(dataSet.get(i));
 
		}
	}
 
	
	//求两点误差平方的方法
	private float errorSquare(float[] element, float[] center) {
		float x = element[0] - center[0];
		float y = element[1] - center[1];
 
		float errSquare = x * x + y * y;
 
		return errSquare;
	}
 
	//计算误差平方和准则函数方法
	private void countRule() {
		float jcF = 0;
		for (int i = 0; i < Cluster.size(); i++) {
			for (int j = 0; j < Cluster.get(i).size(); j++) {
				jcF += errorSquare(Cluster.get(i).get(j), Center.get(i));
 
			}
		}
		SSE.add(jcF);
	}
 
	//设置新的簇中心方法
	private void setNewCenter() {
		for (int i = 0; i < k; i++) {
			int n = Cluster.get(i).size();
			if (n != 0) {
				float[] newCenter = { 0, 0 };
				for (int j = 0; j < n; j++) {
					newCenter[0] += Cluster.get(i).get(j)[0];
					newCenter[1] += Cluster.get(i).get(j)[1];
				}
				// 设置平均值
				newCenter[0] = newCenter[0] / n;
				newCenter[1] = newCenter[1] / n;
				Center.set(i, newCenter);
			}
		}
	}
 
	//打印数据集
	public void printDataArray(ArrayList<float[]> dataArray,
			String dataArrayName) {
		for (int i = 0; i < dataArray.size(); i++) {
			System.out.println("print:" + dataArrayName + "[" + i + "]={"
					+ dataArray.get(i)[0] + "," + dataArray.get(i)[1] + "}");
		}
		System.out.println("————————————————————————————————");
	}
 
	
	private void kmeans() {
		init();
		
		printDataArray(dataSet,"initDataSet");
		printDataArray(Center,"initCenter");
 
		// 循环分组,直到误差不变为止
		while (true) {
			clusterSet();
			for(int i=0;i<Cluster.size();i++)
			{
			 printDataArray(Cluster.get(i),"cluster["+i+"]");
			}
 
			countRule();
 
			System.out.println("count:"+"jc["+n+"]="+SSE.get(n));
 
			System.out.println();
			// 误差不变了,分组完成
			if (n != 0) {
				if (SSE.get(n) - SSE.get(n - 1) == 0) {
					break;
				}
			}
 
			setNewCenter();
			printDataArray(Center,"newCenter");
			n++;
			Cluster.clear();
			Cluster = initCluster();
		}
		System.out.println("note:the times of repeat:n="+n);//输出迭代次数
	}
 

	public void execute() {
		long startTime = System.currentTimeMillis();
		System.out.println("kmeans begins");
		kmeans();
		long endTime = System.currentTimeMillis();
		System.out.println("kmeans running time=" + (endTime - startTime)
				+ "ms");
		System.out.println("kmeans ends");
		System.out.println();
	}
}




2Test.java
package my;
import java.util.ArrayList;
import my.Kmeans;
 
public class Test {
	public  static void main(String[] args)
	{
		//初始化一个Kmean对象,将k置为3
		Kmeans k=new Kmeans(3);
		ArrayList<float[]> dataSet=new ArrayList<float[]>();
		
		dataSet.add(new float[]{0,0});
		dataSet.add(new float[]{3,8});
		dataSet.add(new float[]{2,2});
		dataSet.add(new float[]{1,1});
		dataSet.add(new float[]{5,3});
		dataSet.add(new float[]{4,8});
		dataSet.add(new float[]{6,3});
		dataSet.add(new float[]{5,4});
		dataSet.add(new float[]{6,4});
		dataSet.add(new float[]{7,5});
		//设置原始数据集
		k.setDataSet(dataSet);
		//执行算法
		k.execute();
		//得到聚类结果
		ArrayList<ArrayList<float[]>> cluster=k.getCluster();
		//查看结果
		for(int i=0;i<cluster.size();i++)
		{
			k.printDataArray(cluster.get(i), "cluster["+i+"]");
		}
		
	}
}
  • 2
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值