Java实现K-Means聚类算法

K-means算法基本思想

在数据集中根据一定策略选择K个点作为每个簇的初始中心,将数据划分到距离这K个点最近的簇中,共分成K个类。也就是说将数据划分成K个簇完成一次划分,但形成的新簇并不一定是最好的划分,因此生成的新簇中,重新计算每个簇的中心点,然后再重新进行划分,直到每次划分的结果保持不变。

算法步骤

  • 随机选择K个中心点
  • 把每个数据点分配到离它最近的中心点(此处的距离采用欧氏距离)
  • 重新计算每类中的点到该类中心点距离的平均值
  • 分配每个数据到它最近的中心点
  • 重复步骤3和4,直到每个类别中的数据不再发生变化。

Java实现K-means聚类算法

现有若干鸢尾花的数据,每朵鸢尾花有4个数据,分别为萼片长(单位:厘米)、萼片宽(单位厘米)、花瓣长度(单位厘米)和花瓣宽(单位厘米)。我们希望能找到可行的方法可以按每朵花的4个数据的差异将这些鸢尾花分成若干类,让每一类尽可能的准确,以便帮助植物专家对这些花进行进一步的分析。编程实现K-Means聚类算法,将鸢尾花分类成3类。

数据集样本如下:
在这里插入图片描述
先将以上数据写入文件,文件中的内容如下:

  • Iris.txt
1,5.4,3.4,1.5,0.4
2,5.2,4.1,1.5,0.1
3,5.5,4.2,1.4,0.2
4,4.9,3.1,1.5,0.2
5,5.0,3.2,1.2,0.2
6,5.5,3.5,1.3,0.2
7,4.9,3.6,1.4,0.1
8,4.4,3.0,1.3,0.2
9,5.1,3.4,1.5,0.2
10,5.0,3.5,1.3,0.3
11,4.5,2.3,1.3,0.3
12,4.4,3.2,1.3,0.2
13,5.0,3.5,1.6,0.6
14,5.1,3.8,1.9,0.4
15,4.8,3.0,1.4,0.3
16,5.1,3.8,1.6,0.2
17,4.6,3.2,1.4,0.2
18,5.3,3.7,1.5,0.2
19,5.0,3.3,1.4,0.2
20,7.0,3.2,4.7,1.4
21,6.4,3.2,4.5,1.5
22,6.9,3.1,4.9,1.5
23,5.5,2.3,4.0,1.3
24,6.5,2.8,4.6,1.5
25,5.7,2.8,4.5,1.3
26,6.3,3.3,4.7,1.6
27,4.9,2.4,3.3,1.0
28,6.6,2.9,4.6,1.3
29,5.2,2.7,3.9,1.4
30,5.0,2.0,3.5,1.0

Iris.txt 文件放在 Kmeans .java 同级目录下即可。

程序运行结果:
由于数据太长,只截取了一部分,如下:

在这里插入图片描述
在这里插入图片描述

从以上结果可以看到,第10次迭代后产生的分类结果和第9次完全相同,故分类完成,共迭代10次,算法结束。最后一次的迭代结果即为最终的分类结果。

代码如下

package main;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.InputStreamReader;
import java.util.ArrayList;

public class Kmeans {

	// 记录迭代的次数
	static int count = 1;
	// 文件所在路径
	static String filePath = System.getProperty("user.dir")+"\\src\\main\\Iris.txt";
	// 储存从文件中读取的数据
	static ArrayList<ArrayList<Float>> table = new ArrayList<ArrayList<Float>>();
	// 储存分类一的结果
	static ArrayList<ArrayList<Float>> alist = new ArrayList<ArrayList<Float>>();
	// 储存分类二的结果
	static ArrayList<ArrayList<Float>> blist = new ArrayList<ArrayList<Float>>();
	// 储存分类三的结果
	static ArrayList<ArrayList<Float>> clist = new ArrayList<ArrayList<Float>>();
	// 记录初始随机产生的3个聚类中心
	static ArrayList<ArrayList<Float>> randomList = new ArrayList<ArrayList<Float>>();
	
	// 读取文件中的数据,储存到集合中
	public static ArrayList<ArrayList<Float>> readTable(String filePath){
		ArrayList<Float> d = null;
		File file = new File(filePath);
		try {
			InputStreamReader isr = new InputStreamReader(new FileInputStream(file));
			BufferedReader bf = new BufferedReader(isr);
			String str = null;
			while((str = bf.readLine()) != null) {
				d = new ArrayList<Float>();
				String[] str1 = str.split(",");
				for(int i = 0; i < str1.length ; i++) {
					d.add(Float.parseFloat(str1[i]));
				}
				table.add(d);
			}
//			System.out.println(table);
			bf.close();
			isr.close();
		} catch (Exception e) {
			e.printStackTrace();
			System.out.println("文件不存在!");
		}
		return table;
	}
	
	// 随机产生3个初始聚类中心
	public static ArrayList<ArrayList<Float>> randomList() {
		int[] list = new int[3];
		//产生3个互不相同的随机数
		do {
			list[0] = (int)(Math.random()*30);
			list[1] = (int)(Math.random()*30);
			list[2] = (int)(Math.random()*30);
		}while(list[0] == list[1] && list[0] == list[2] && list[1] == list[2]);
//		System.out.println("索引:"+list[0]+" "+list[1]+" "+list[2]);
// 为了测试方便,我这里去数据集中前3个作为初始聚类中心
		for(int i = 0; i < 3 ; i++) {
			//randomList.add(list[i]);
			randomList.add(table.get(i));
		 }
		return randomList;
	}
	
	//比较两个数的大小,并返回其中较小的数
	public static double minNumber(double x, double y) {
		if(x < y) {
			return x;
		}
		return y;
	}
	
	// 计算各个数据到三个中心点的距离,然后分成三类
	public static void eudistance(ArrayList<ArrayList<Float>> list){
		alist.clear();
		blist.clear();
		clist.clear();
		double minNumber;
		double distancea,distanceb,distancec;
//		System.out.println("randomList:"+randomList);
		for(int i = 0; i < table.size() ; i++) {
			distancea = Math.pow(table.get(i).get(1)-list.get(0).get(1), 2) +
					Math.pow(table.get(i).get(2)-list.get(0).get(2), 2) + 
					Math.pow(table.get(i).get(3)-list.get(0).get(3), 2) + 
					Math.pow(table.get(i).get(4)-list.get(0).get(4), 2);
			distanceb = Math.pow(table.get(i).get(1)-list.get(1).get(1), 2) +
					Math.pow(table.get(i).get(2)-list.get(1).get(2), 2) +
					Math.pow(table.get(i).get(3)-list.get(1).get(3), 2) +
					Math.pow(table.get(i).get(4)-list.get(1).get(4), 2);
			distancec = Math.pow(table.get(i).get(1)-list.get(2).get(1), 2) +
					Math.pow(table.get(i).get(2)-list.get(2).get(2), 2) +
					Math.pow(table.get(i).get(3)-list.get(2).get(3), 2) +
					Math.pow(table.get(i).get(4)-list.get(2).get(4), 2);
			minNumber = minNumber(minNumber(distancea,distanceb),distancec);
			if(minNumber == distancea) {
				alist.add(table.get(i));
			}else if(minNumber == distanceb) {
				blist.add(table.get(i));
			}else {
				clist.add(table.get(i));
			}
		 }
		System.out.println("第"+count+"次迭代:");
		System.out.println(alist);
		System.out.println(blist);
		System.out.println(clist);
		System.out.println("\n");
		count++;
	}
	
	// 计算每个类中四个数据的平均值,重新生成聚类中心
	public static ArrayList<Float> newList(ArrayList<ArrayList<Float>> list){
		float avnum1,avnum2,avnum3,avnum4,c=0f;
		float sum1 = 0,sum2 = 0,sum3 = 0,sum4 = 0;
		ArrayList<Float> k = new ArrayList<Float>();
		for(int i = 0; i < list.size(); i++) {
			sum1 += list.get(i).get(1);
			sum2 += list.get(i).get(2);
			sum3 += list.get(i).get(3);
			sum4 += list.get(i).get(4);
		}
		avnum1 = (float)(sum1*1.0 / list.size());
		avnum2 = (float)(sum2*1.0 / list.size());
		avnum3 = (float)(sum3*1.0 / list.size());
		avnum4 = (float)(sum4*1.0 / list.size());
		k.add(c);
		k.add(avnum1);
		k.add(avnum2);
		k.add(avnum3);
		k.add(avnum4);
		return k;
	}
	
	// 判断两个集合的元素是否完全相同,若相同,则返回1;否则,返回0
	public static int same(ArrayList<ArrayList<Float>> list1, ArrayList<ArrayList<Float>> list2) {
		int countn = 0;
		if(list1.size()==list2.size()) {
			for(int i = 0; i < list1.size() ; i++) {
				for(int j = 0; j < list2.size() ; j++) {
					if(list1.get(i).containsAll(list2.get(j)) && list2.get(j).containsAll(list1.get(i))) {
						countn++;
						break;
					}
				}
			}
		}
		if(countn == list1.size()) {
			return 1;
		}else {
			return 0;
		}
	}
	
	// 迭代求出最后的分类结果
	public static void kmeans() {
		int a,b,c,k=0;
		ArrayList<ArrayList<Float>> klist = null;
		ArrayList<ArrayList<Float>> arlist = null;
		ArrayList<ArrayList<Float>> brlist = null;
		ArrayList<ArrayList<Float>> crlist = null;
		do {
			klist = new ArrayList<ArrayList<Float>>();
			arlist = new ArrayList<ArrayList<Float>>();
			brlist = new ArrayList<ArrayList<Float>>();
			crlist = new ArrayList<ArrayList<Float>>();
			arlist.addAll(alist);
			brlist.addAll(blist);
			crlist.addAll(clist);
			klist.clear();
			klist.add(newList(alist));
			klist.add(newList(blist));
			klist.add(newList(clist));
			eudistance(klist);
			a = same(alist,arlist);
			b = same(blist,brlist);
			c = same(clist,crlist);
			if(a == 1 && b == 1 && c == 1) {
				Kmeans.count = 1;
				break;
			}
		}while(true);
	}

	public static void main(String[] args) {
		ArrayList<ArrayList<Float>> rlist = new ArrayList<ArrayList<Float>>();
		readTable(filePath);
		rlist = randomList();
		eudistance(rlist);
		kmeans();
	}
}
  • 3
    点赞
  • 44
    收藏
    觉得还不错? 一键收藏
  • 3
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值