Java实验:K-means算法

1)实验题目: K-means算法是经典的聚类算法,其基本思想是:以空间中k个点为中心进行聚类,对最靠近他们的对象归类。通过迭代的方法,逐次更新各聚类中心的值,直至得到最好的聚类结果。假设要把样本集分为K个类别,算法描述如下:

(1)适当选择k个类的初始中心

(2)在第I次迭代中,对任意一个样本,求其到K个中心的距离,将该样本归到距离最短的中心所在的类

(3)利用均值方式更新该类的中心值

(4)对于所有的K个聚类中心,如果利用(2)(3)的迭代法更新后,值保持基本不变,则迭代结束,否则继续迭代。

要求用java编写K-means算法(k值可以自己设定),根据花的属性对数据集Iris Data Set进行聚类,并将聚类结果(sepal length,sepal width,petal length,petal width ,cluster label)打印至cluster.txt文件。

iris数据包括四个属性:sepal length花萼长度,sepal width花萼宽度,petal length花瓣长度,petal width花瓣宽度。其中第五个值表示该样本属于哪一个类。Iris.data 可以用写字板打开。

注意:样本点间的距离直接用向量的欧氏距离。

 

2)实验思路:

建立一个Kmeans类,它有以下变量和方法:

变量ArrayList<ArrayList<Double>> data存储从文件中读取的数据。用HashMap存储原来文件中的字符串对应的数字。变量k代表一共要分成k个类。

方法creatDataSet读入文件,建立数据集,其中前四列是原始文件前四列的数据,第五列是原始文件第五列字符串对应的数字。

方法getDistance计算两组数据的欧式距离;findCluster根据一组数据距离哪个类的中心最近找到它的所属类;getSSE计算所有数据的整体误差平方和;getMeans计算类的中心;

方法kMeans是核心方法,一开始随机选择中心,生成类。然后更新类的中心,再重新生成类,不断迭代直到误差平方和的变化小于0.2,得出分类结果。

方法printCluster将分类结果输出到cluster.txt中,前四列是数据,第五列是该组数据对应的类的编号。

方法calculateAccuracy计算分类的准确率。它的原理是分类结束后,一个类中出现次数最多的原始标签(指根据原始文件第五列赋予的标签)对应的数据计为分类正确,该类中属于其它标签的少数数据不应该处在这个类中,计为分类错误。将所有类中分类正确的数据量除以总数据量,即为分类准确率。

在主类中新建Kmeans类的实例,调用其方法完成实验。

 

3)实验源码

package 实验;

import java.io.*;
import java.util.*;

class Kmeans
{
	private ArrayList<ArrayList<Double>> data =new ArrayList<ArrayList<Double>>(); //存储数据
	private HashMap<String,Double> mp= new HashMap<String, Double>();  //存储样本属于的类和对应的数字
	int k;  //一共几个簇
	
	public void creatDataSet(String filename) throws IOException  //读入文件,建立数据集
	{
		mp.put("Iris-setosa", 1.0);
		mp.put("Iris-versicolor", 2.0);
		mp.put("Iris-virginica", 3.0);
		BufferedReader br = new BufferedReader(new FileReader(filename));
		String string;
		while((string=br.readLine())!=null)
		{
			String[] a=string.split(",");
			data.add(new ArrayList<Double>());
			int l=data.size();
			for(int i=0;i<4;i++)
			{
				data.get(l-1).add(Double.parseDouble(a[i]));  //放入前四个数值
			}
			data.get(l-1).add(mp.get(a[4]));  //放入最后一个标签
		}
		br.close();
	}
	
	public double getDistance(ArrayList<Double> x,ArrayList<Double> y) //计算欧式距离
	{
		double distance=0;
		for(int i=0;i<4;i++) distance+=Math.pow(x.get(i)-y.get(i), 2);
		return Math.sqrt(distance);
	}
	
	public int findCluster(ArrayList<Double>[] means,ArrayList<Double> y) { //找到所属类
		double mindistance=getDistance(y,means[0]);
		double tempdistance;
		int label=0;
		for(int i=1;i<k;i++)
		{
			tempdistance=getDistance(y,means[i]);
			if(tempdistance<mindistance) 
			{
				mindistance=tempdistance;
				label=i;
			}
		}
		return label;
	}
	
	public double getSSE(ArrayList<ArrayList<Double>>[] cluster,ArrayList<Double>[] means) //误差平方和
	{
		double sse=0;
		for(int i=0;i<k;i++)
		{
			for(int j=0;j<cluster[i].size();j++)
			{
				sse+=getDistance(cluster[i].get(j), means[i]);
			}
		}
		return sse;
	}
	
	public ArrayList<Double> getMeans(ArrayList<ArrayList<Double>> cluster) //获取类的中心
	{
		int l=cluster.size();
		 ArrayList<Double> mean=new ArrayList<Double>(4);
		 for(int i=0;i<4;i++)mean.add(0.0);
		 for(int i=0;i<l;i++)
		 {
			 for(int j=0;j<4;j++) mean.set(j, mean.get(j)+cluster.get(i).get(j));
		 }
		 for(int i=0;i<4;i++)
		 {
			 mean.set(i, mean.get(i)/l);
		 }
		 return mean;
	}
	
	public void kMeans()  //迭代更新类
	{
		ArrayList<ArrayList<Double>>[] cluster = new ArrayList[k]; //类
		ArrayList<Double>[] means = new ArrayList[k];            //中心
		Random random = new Random();
		for(int i=0;i<k;i++)  //随机寻找中心
		{
			means[i]=new ArrayList<Double>();
			cluster[i]=new ArrayList<ArrayList<Double>>(); //分配空间
			int randnum=random.nextInt(data.size()-1);
			for(int j=0;j<4;j++) means[i].add(data.get(randnum).get(j));
		}
		for(int i=0;i<data.size();i++)  //生成类
		{
			int label=findCluster(means, data.get(i));
			cluster[label].add(data.get(i));
		}
		double sse=getSSE(cluster, means),pre_sse=0;
		System.out.println("初始整体误差平方和:"+sse);
		int iteration=0;  //迭代次数
		do {
			pre_sse=sse;
			for(int i=0;i<k;i++)
			{
				means[i]=getMeans(cluster[i]); //更新类的中心
			}
			for(int i=0;i<k;i++)   //清空类
			{
				cluster[i].clear();
			}
			for(int i=0;i<data.size();i++) //重新生成类
			{
				int label=findCluster(means, data.get(i));
				cluster[label].add(data.get(i));
			}
			sse=getSSE(cluster, means); //获取新的误差平方和
			System.out.println("第"+(++iteration)+"次迭代后整体误差平方和为:"+sse);
		} while (Math.abs(sse-pre_sse)>=0.2);
		printCluster("G:\\code\\java\\实验\\cluster.txt",cluster); //将分类结果输出
		calculateAccuracy(cluster);
	}
	
	public void calculateAccuracy(ArrayList<ArrayList<Double>>[] cluster)  //计算准确率
	{ 
		int m=0;
		for(int i=0;i<k;i++)
		{
			int a=0,b=0,c=0;
			for(int j=0;j<cluster[i].size();j++)
			{
				switch (cluster[i].get(j).get(4).intValue()){
				case 1: a++; break;
				case 2: b++; break;
				case 3: c++; break;
				default:;
				}
			}
			m+=Math.max(c, Math.max(a, b));
		}
		System.out.println("分类结果准确率:"+(double)m/150*100+"%");
	}
	
	public void printCluster(String filename,ArrayList<ArrayList<Double>>[] cluster)//输出cluster.txt
	{
		try {
			File file = new File(filename);
			file.createNewFile();
			DataOutputStream ds=new DataOutputStream(new FileOutputStream(file));
			for(int i=0;i<k;i++)
			{
				for(int j=0;j<cluster[i].size();j++)
				{
					for(int u=0;u<4;u++)
					{
						ds.writeBytes(""+cluster[i].get(j).get(u)); //输出一行数据
						ds.writeBytes(",");
					}
					ds.writeBytes(i+1+"\r\n");  //输出类的编号并换行
				}
			}
			ds.close();
		} catch (Exception e) {
			System.out.println(e.getMessage());
		}
		System.out.println("分类结果已输出到cluster.txt");
	}
	
}

public class Text {
	
	public static void main(String[] args) {
		try {
			Kmeans kmeans=new Kmeans();
			kmeans.creatDataSet("G:\\code\\java\\实验\\iris.data");
			kmeans.k=3;
			kmeans.kMeans();
		} catch (IOException e) {
			System.out.println(e.getMessage());
		}
		
	}

}

 

4)实验结果:

此次实验将k设置为3,多次迭代后整体误差平方和收敛,准确率为89.3%。

 

5)实验心得:

此题设涉及的知识点有k-means算法的理解,文件输入输出,Arraylist的使用,分类准确率的计算。

k-means算法把距离作为相似性评价的指标,根据中心生成类,再根据类更新中心,不断迭代得到最后的分类结果。实验时一开始随机选择中心,生成类。然后更新类的中心,再重新生成类,不断迭代直到误差平方和的变化小于0.2,得出分类结果。

这里用BufferedReader读入数据,每次读入一行,将一行字符串根据逗号分割成字符串数组,前四个字符串转化为double,最后一个字符串根据HashMap转化为对应的数字存入。使用DataOutputStream将分类结果输出到cluster.txt文件。

因为读入数据之前不知道数据的数量,这里使用ArrayList<ArrayList<Double>>存储数据,可以动态改变大小。运算过程中使用了其get,set,add等方法。

最后根据原始文件对数据打上的标签来判断实验结果的准确率。它的原理是分类结束后,一个类中出现次数最多的原始标签对应的数据计为分类正确,该类中属于其它标签的少数数据不应该处在这个类中,计为分类错误。将所有类中分类正确的数据量除以总数据量,即为分类准确率。

 

  • 4
    点赞
  • 26
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
对于K-Means算法实验结果分析可以从以下几个方面考虑: 1. 聚类效果: 聚类效果是评估K-Means算法的一个重要指标。可以通过观察聚类结果的可视化图像来进行初步判断。聚类结果的可视化图像通常是将原始数据按照聚类结果进行颜色标记,以便更直观地观察聚类效果。如果聚类结果比较好,那么不同的聚类簇应该是相对分离的,同一簇内的数据应该比较相似。 2. 聚类数量: K-Means算法的一个重要参数是聚类数量k。不同的k值可能会导致不同的聚类效果。因此,在实验需要尝试不同的k值,来寻找最优的聚类数量。通常可以使用轮廓系数或者SSE来评估不同k值下的聚类效果,以便选择最优的聚类数量。 3. 初始质心选取K-Means算法的另一个重要参数是初始质心的选取。不同的初始质心可能会导致不同的聚类效果。因此,在实验需要尝试不同的初始质心选取方法,以获得更好的聚类效果。 4. 算法效率: K-Means算法的效率通常比较高,但是随着数据量的增加,算法的计算复杂度也会增加。因此,在实验需要对算法的计算时间进行评估,以便对算法进行优化。 总之,K-Means算法实验结果分析需要从多个方面进行考虑。通过综合考虑聚类效果、聚类数量、初始质心选取和算法效率等因素,可以得出更准确的结论,并对算法进行进一步优化。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值