(1)实验题目: K-means算法是经典的聚类算法,其基本思想是:以空间中k个点为中心进行聚类,对最靠近他们的对象归类。通过迭代的方法,逐次更新各聚类中心的值,直至得到最好的聚类结果。假设要把样本集分为K个类别,算法描述如下:
(1)适当选择k个类的初始中心
(2)在第I次迭代中,对任意一个样本,求其到K个中心的距离,将该样本归到距离最短的中心所在的类
(3)利用均值方式更新该类的中心值
(4)对于所有的K个聚类中心,如果利用(2)(3)的迭代法更新后,值保持基本不变,则迭代结束,否则继续迭代。
要求用java编写K-means算法(k值可以自己设定),根据花的属性对数据集Iris Data Set进行聚类,并将聚类结果(sepal length,sepal width,petal length,petal width ,cluster label)打印至cluster.txt文件。
iris数据包括四个属性:sepal length花萼长度,sepal width花萼宽度,petal length花瓣长度,petal width花瓣宽度。其中第五个值表示该样本属于哪一个类。Iris.data 可以用写字板打开。
注意:样本点间的距离直接用向量的欧氏距离。
(2)实验思路:
建立一个Kmeans类,它有以下变量和方法:
变量ArrayList<ArrayList<Double>> data存储从文件中读取的数据。用HashMap存储原来文件中的字符串对应的数字。变量k代表一共要分成k个类。
方法creatDataSet读入文件,建立数据集,其中前四列是原始文件前四列的数据,第五列是原始文件第五列字符串对应的数字。
方法getDistance计算两组数据的欧式距离;findCluster根据一组数据距离哪个类的中心最近找到它的所属类;getSSE计算所有数据的整体误差平方和;getMeans计算类的中心;
方法kMeans是核心方法,一开始随机选择中心,生成类。然后更新类的中心,再重新生成类,不断迭代直到误差平方和的变化小于0.2,得出分类结果。
方法printCluster将分类结果输出到cluster.txt中,前四列是数据,第五列是该组数据对应的类的编号。
方法calculateAccuracy计算分类的准确率。它的原理是分类结束后,一个类中出现次数最多的原始标签(指根据原始文件第五列赋予的标签)对应的数据计为分类正确,该类中属于其它标签的少数数据不应该处在这个类中,计为分类错误。将所有类中分类正确的数据量除以总数据量,即为分类准确率。
在主类中新建Kmeans类的实例,调用其方法完成实验。
(3)实验源码
package 实验;
import java.io.*;
import java.util.*;
class Kmeans
{
private ArrayList<ArrayList<Double>> data =new ArrayList<ArrayList<Double>>(); //存储数据
private HashMap<String,Double> mp= new HashMap<String, Double>(); //存储样本属于的类和对应的数字
int k; //一共几个簇
public void creatDataSet(String filename) throws IOException //读入文件,建立数据集
{
mp.put("Iris-setosa", 1.0);
mp.put("Iris-versicolor", 2.0);
mp.put("Iris-virginica", 3.0);
BufferedReader br = new BufferedReader(new FileReader(filename));
String string;
while((string=br.readLine())!=null)
{
String[] a=string.split(",");
data.add(new ArrayList<Double>());
int l=data.size();
for(int i=0;i<4;i++)
{
data.get(l-1).add(Double.parseDouble(a[i])); //放入前四个数值
}
data.get(l-1).add(mp.get(a[4])); //放入最后一个标签
}
br.close();
}
public double getDistance(ArrayList<Double> x,ArrayList<Double> y) //计算欧式距离
{
double distance=0;
for(int i=0;i<4;i++) distance+=Math.pow(x.get(i)-y.get(i), 2);
return Math.sqrt(distance);
}
public int findCluster(ArrayList<Double>[] means,ArrayList<Double> y) { //找到所属类
double mindistance=getDistance(y,means[0]);
double tempdistance;
int label=0;
for(int i=1;i<k;i++)
{
tempdistance=getDistance(y,means[i]);
if(tempdistance<mindistance)
{
mindistance=tempdistance;
label=i;
}
}
return label;
}
public double getSSE(ArrayList<ArrayList<Double>>[] cluster,ArrayList<Double>[] means) //误差平方和
{
double sse=0;
for(int i=0;i<k;i++)
{
for(int j=0;j<cluster[i].size();j++)
{
sse+=getDistance(cluster[i].get(j), means[i]);
}
}
return sse;
}
public ArrayList<Double> getMeans(ArrayList<ArrayList<Double>> cluster) //获取类的中心
{
int l=cluster.size();
ArrayList<Double> mean=new ArrayList<Double>(4);
for(int i=0;i<4;i++)mean.add(0.0);
for(int i=0;i<l;i++)
{
for(int j=0;j<4;j++) mean.set(j, mean.get(j)+cluster.get(i).get(j));
}
for(int i=0;i<4;i++)
{
mean.set(i, mean.get(i)/l);
}
return mean;
}
public void kMeans() //迭代更新类
{
ArrayList<ArrayList<Double>>[] cluster = new ArrayList[k]; //类
ArrayList<Double>[] means = new ArrayList[k]; //中心
Random random = new Random();
for(int i=0;i<k;i++) //随机寻找中心
{
means[i]=new ArrayList<Double>();
cluster[i]=new ArrayList<ArrayList<Double>>(); //分配空间
int randnum=random.nextInt(data.size()-1);
for(int j=0;j<4;j++) means[i].add(data.get(randnum).get(j));
}
for(int i=0;i<data.size();i++) //生成类
{
int label=findCluster(means, data.get(i));
cluster[label].add(data.get(i));
}
double sse=getSSE(cluster, means),pre_sse=0;
System.out.println("初始整体误差平方和:"+sse);
int iteration=0; //迭代次数
do {
pre_sse=sse;
for(int i=0;i<k;i++)
{
means[i]=getMeans(cluster[i]); //更新类的中心
}
for(int i=0;i<k;i++) //清空类
{
cluster[i].clear();
}
for(int i=0;i<data.size();i++) //重新生成类
{
int label=findCluster(means, data.get(i));
cluster[label].add(data.get(i));
}
sse=getSSE(cluster, means); //获取新的误差平方和
System.out.println("第"+(++iteration)+"次迭代后整体误差平方和为:"+sse);
} while (Math.abs(sse-pre_sse)>=0.2);
printCluster("G:\\code\\java\\实验\\cluster.txt",cluster); //将分类结果输出
calculateAccuracy(cluster);
}
public void calculateAccuracy(ArrayList<ArrayList<Double>>[] cluster) //计算准确率
{
int m=0;
for(int i=0;i<k;i++)
{
int a=0,b=0,c=0;
for(int j=0;j<cluster[i].size();j++)
{
switch (cluster[i].get(j).get(4).intValue()){
case 1: a++; break;
case 2: b++; break;
case 3: c++; break;
default:;
}
}
m+=Math.max(c, Math.max(a, b));
}
System.out.println("分类结果准确率:"+(double)m/150*100+"%");
}
public void printCluster(String filename,ArrayList<ArrayList<Double>>[] cluster)//输出cluster.txt
{
try {
File file = new File(filename);
file.createNewFile();
DataOutputStream ds=new DataOutputStream(new FileOutputStream(file));
for(int i=0;i<k;i++)
{
for(int j=0;j<cluster[i].size();j++)
{
for(int u=0;u<4;u++)
{
ds.writeBytes(""+cluster[i].get(j).get(u)); //输出一行数据
ds.writeBytes(",");
}
ds.writeBytes(i+1+"\r\n"); //输出类的编号并换行
}
}
ds.close();
} catch (Exception e) {
System.out.println(e.getMessage());
}
System.out.println("分类结果已输出到cluster.txt");
}
}
public class Text {
public static void main(String[] args) {
try {
Kmeans kmeans=new Kmeans();
kmeans.creatDataSet("G:\\code\\java\\实验\\iris.data");
kmeans.k=3;
kmeans.kMeans();
} catch (IOException e) {
System.out.println(e.getMessage());
}
}
}
(4)实验结果:
此次实验将k设置为3,多次迭代后整体误差平方和收敛,准确率为89.3%。
(5)实验心得:
此题设涉及的知识点有k-means算法的理解,文件输入输出,Arraylist的使用,分类准确率的计算。
k-means算法把距离作为相似性评价的指标,根据中心生成类,再根据类更新中心,不断迭代得到最后的分类结果。实验时一开始随机选择中心,生成类。然后更新类的中心,再重新生成类,不断迭代直到误差平方和的变化小于0.2,得出分类结果。
这里用BufferedReader读入数据,每次读入一行,将一行字符串根据逗号分割成字符串数组,前四个字符串转化为double,最后一个字符串根据HashMap转化为对应的数字存入。使用DataOutputStream将分类结果输出到cluster.txt文件。
因为读入数据之前不知道数据的数量,这里使用ArrayList<ArrayList<Double>>存储数据,可以动态改变大小。运算过程中使用了其get,set,add等方法。
最后根据原始文件对数据打上的标签来判断实验结果的准确率。它的原理是分类结束后,一个类中出现次数最多的原始标签对应的数据计为分类正确,该类中属于其它标签的少数数据不应该处在这个类中,计为分类错误。将所有类中分类正确的数据量除以总数据量,即为分类准确率。