kmeans算法的java实现

本文转载自 http://www.cnblogs.com/zhangchaoyang/articles/2181869.html ,并对其内容进行了补充和完善,使代码可以直接运行;运算所用的原始数据由随机数生成。

图示为将1000个二维数据点聚成3个簇的聚类结果。

主程序:

import java.io.File;
import java.io.FileWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.Random;
/**
 * K-means clustering of {@link DataObject} points.
 *
 * <p>One instance is reused for several independent runs (see {@link #main}); each run
 * starts from a fresh random partition, iterates until no centroid moves by more than
 * {@code mu}, writes every cluster to {@code out<i>.txt} and reports a score.
 */
public class KMeans {
	int k;				// number of clusters
	double mu;			// convergence threshold: iteration stops once no centroid moves farther than mu
	double[][] center;	// current centroid of every cluster (k rows, len columns)
	int repeat;			// number of independent runs performed by main
	double[] crita;		// score returned by each run, indexed by run number

	/**
	 * @param k      number of clusters
	 * @param mu     convergence threshold for centroid movement
	 * @param repeat number of independent runs
	 * @param len    dimensionality of every point
	 */
	public KMeans(int k,double mu,int repeat,int len){
		this.k=k;
		this.mu=mu;
		this.repeat=repeat;
		center=new double[k][len];
		crita=new double[repeat];
	}

	/**
	 * Builds the initial centroids by assigning every object to a random cluster and
	 * averaging the members of each cluster.
	 *
	 * @param len     number of coordinates used from each vector
	 * @param objects points to partition; their cluster id is overwritten
	 */
	public void initCenter(int len,ArrayList<DataObject> objects){
		Random random=new Random();
		// run() is called repeatedly on the same instance, so the centroids left over
		// from the previous run must be cleared before accumulating into them.
		for(int i=0;i<k;i++)
			for(int j=0;j<len;j++)
				center[i][j]=0;
		int[] count=new int[k];
		for(DataObject object:objects){
			int id=random.nextInt(k);
			count[id]++;
			object.setCid(id);
			double[] vector=object.getVector();
			for(int j=0;j<len;j++)
				center[id][j]+=vector[j];
		}
		for(int i=0;i<k;i++){
			if(count[i]>0){
				for(int j=0;j<len;j++)
					center[i][j]/=count[i];
			}
			else if(!objects.isEmpty()){
				// No object was drawn for this cluster: seed it with a random point
				// instead of computing 0/0, which would produce a NaN centroid.
				double[] seed=objects.get(random.nextInt(objects.size())).getVector();
				for(int j=0;j<len;j++)
					center[i][j]=seed[j];
			}
		}
	}

	/**
	 * Assigns every object to the cluster whose centroid is nearest (Euclidean distance).
	 * Ties keep the lowest cluster index.
	 *
	 * @param objects points to classify; their cluster id is overwritten
	 */
	public void classify(ArrayList<DataObject> objects){
		for(DataObject object:objects){
			double[] vector=object.getVector();
			int len=vector.length;
			int index=0;
			double neardist=Double.MAX_VALUE;
			for(int i=0;i<k;i++){
				double dist=Global.calEuraDist(vector, center[i], len);
				if(dist<neardist){
					neardist=dist;
					index=i;
				}
			}
			object.setCid(index);
		}
	}

	/**
	 * Recomputes the centroids from the current assignment.
	 *
	 * @param objects classified points
	 * @param len     number of coordinates used from each vector
	 * @return {@code true} when every new centroid lies within {@code mu} of the old one
	 *         (converged; {@code center} is left unchanged), {@code false} otherwise
	 *         ({@code center} is replaced by the new centroids)
	 */
	public boolean calNewCenter(ArrayList<DataObject> objects,int len){
		int[] count=new int[k];
		double[][] sum=new double[k][len];
		for(DataObject object:objects){
			int id=object.getCid();
			count[id]++;
			double[] vector=object.getVector();
			for(int j=0;j<len;j++)
				sum[id][j]+=vector[j];
		}
		for(int i=0;i<k;i++){
			if(count[i]!=0){
				for(int j=0;j<len;j++)
					sum[i][j]/=count[i];
			}
			else{
				// Empty cluster: re-seed it at the mean of three existing centroids.
				// The result has to go into sum[i]; writing it to center[i] would let the
				// all-zero sum[i] fail the convergence test and then overwrite the re-seed.
				int a=(i+1)%k;
				int b=(i+3)%k;
				int c=(i+5)%k;
				for(int j=0;j<len;j++)
					sum[i][j]=(center[a][j]+center[b][j]+center[c][j])/3;
			}
		}
		boolean end=true;
		for(int i=0;i<k;i++){
			if(Global.calEuraDist(sum[i],center[i],len)>mu){
				end=false;
				break;
			}
		}
		if(!end){
			for(int i=0;i<k;i++)
				System.arraycopy(sum[i],0,center[i],0,len);
		}
		return end;
	}

	/**
	 * Scores the current clustering: for every cluster, the sum of squared distances of
	 * its members to the centroid is multiplied by the member count, and the products are
	 * added up. Smaller is better.
	 *
	 * @param objects classified points
	 * @param len     number of coordinates used from each vector
	 * @return the score of the current clustering
	 */
	public double getSati(ArrayList<DataObject> objects,int len){
		int[] count=new int[k];
		double[] ss=new double[k];
		for(DataObject object:objects){
			int id=object.getCid();
			count[id]++;
			double[] vector=object.getVector();
			for(int j=0;j<len;j++)
				ss[id]+=Math.pow(vector[j]-center[id][j],2);
		}
		double satisfy=0;
		for(int i=0;i<k;i++)
			satisfy+=count[i]*ss[i];
		return satisfy;
	}

	/**
	 * Performs one complete clustering run and writes cluster i to {@code out<i>.txt}
	 * (one point per line, coordinates separated by spaces).
	 *
	 * @param round      run number, used only for the progress message
	 * @param datasource provider of the points to cluster
	 * @param len        number of coordinates used from each vector
	 * @return the score of the final clustering as computed by {@link #getSati}
	 * @throws IOException if an output file cannot be written
	 */
	public double run(int round,DataSource datasource,int len) throws IOException{
		System.out.println("第"+round+"次运行");
		ArrayList<DataObject> objects=datasource.getObjects();

		initCenter(len,objects);
		classify(objects);
		while(!calNewCenter(objects,len)){
			classify(objects);
		}
		writeClusters(objects);
		double ss=getSati(objects,len);
		System.out.println("加权方差:"+ss);
		return ss;
	}

	// Writes every object to the output file of its cluster, always closing all files.
	private void writeClusters(ArrayList<DataObject> objects) throws IOException{
		FileWriter[] filewriter=new FileWriter[this.k];
		try{
			for(int i=0;i<this.k;i++)
				filewriter[i]=new FileWriter("out"+i+".txt");
			for(DataObject tmp:objects){
				FileWriter out=filewriter[tmp.getCid()];
				for(double value:tmp.getVector()){
					out.write(Double.toString(value));
					out.write(" ");
				}
				out.write("\n");
			}
		}
		finally{
			closeAll(filewriter);
		}
	}

	// Closes every opened writer; the first close failure is rethrown after all were tried.
	private static void closeAll(FileWriter[] writers) throws IOException{
		IOException first=null;
		for(FileWriter writer:writers){
			if(writer==null)
				continue;
			try{
				writer.close();
			}
			catch(IOException e){
				if(first==null)
					first=e;
			}
		}
		if(first!=null)
			throw first;
	}

	/**
	 * Clusters 1000 random two-dimensional points into 3 clusters, repeating the run
	 * 10 times and reporting which run achieved the lowest score.
	 * NOTE: the out<i>.txt files are rewritten by every run, so after the program ends
	 * they hold the clusters of the last run, not necessarily of the best one.
	 */
	public static void main(String[] args) throws IOException{
		int k=3;
		double mu= 1E-10;
		int repeat=10;
		int len=2;
		int node=1000;
		DataSource resource=new DataSource(node,len);
		KMeans km=new KMeans(k,mu,repeat,len);
		int index=0;
		double minsa=Double.MAX_VALUE;
		for(int i=0;i<km.repeat;i++){
			double ss=km.run(i,resource,len);
			km.crita[i]=ss;
			if(ss<minsa){
				minsa=ss;
				index=i;
			}
		}
		System.out.println("最好的结果是第"+index+"次。");
	}
}
DataSource类:
import java.io.*;
import java.text.DecimalFormat;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.Random;
import java.util.Scanner;
/**
 * Generates a random data set, stores it in a text file and loads it back as
 * {@link DataObject}s.
 */
class DataSource {
	// Single name for the generated data file, so that the writer (initData) and the
	// reader (constructor) always refer to the same file in the working directory.
	private static final String DATA_FILE="o.txt";

	int node;			//number of nodes
	int len;			//dimensionality
	ArrayList<DataObject> objects;

	/**
	 * Creates {@code node} random points of {@code len} dimensions and loads them.
	 *
	 * @param node number of points to generate
	 * @param len  number of coordinates per point
	 * @throws IOException if the data file cannot be written or read
	 */
	public DataSource(int node,int len) throws IOException{
		this.node=node;
		this.len=len;
		initData();
		objects=new ArrayList<DataObject>();
		Scanner scanner=new Scanner(new FileInputStream(DATA_FILE));
		try{
			for(int i=0;i<node;i++){
				double[] data=new double[this.len];
				for(int j=0;j<this.len;j++){
					data[j]=scanner.nextDouble();
				}
				objects.add(new DataObject(data,len));
			}
		}
		finally{
			// release the file handle even when parsing fails
			scanner.close();
		}
	}

	/** Returns the loaded points (the internal list, not a copy). */
	public ArrayList<DataObject> getObjects(){
		return objects;
	}

	/**
	 * Writes {@code node} lines to the data file, each holding {@code len} random values
	 * in [0, 1000) formatted with two decimals and separated by spaces.
	 *
	 * @throws IOException if the data file cannot be written
	 */
	void initData() throws IOException{
		Random random = new Random(System.currentTimeMillis());
		// the formatter is loop-invariant, so it is created once
		DecimalFormat fnum=new DecimalFormat("0.00");
		FileWriter writer = new FileWriter(DATA_FILE);
		try{
			for(int i=0;i<node;i++){
				for(int j=0;j<len;j++){
					double t1=random.nextDouble()*1000;
					writer.write(fnum.format(t1));
					writer.write(" ");
				}
				writer.write("\n");
			}
		}
		finally{
			writer.close();
		}
	}
}

DataObject类:

/**
 * A single data point together with the id of the cluster it currently belongs to.
 */
class DataObject{
	double [] data;		// coordinates of the point, one entry per dimension
	int len;			// number of dimensions
	int index;			// id of the cluster this point is assigned to

	/**
	 * Stores a private copy of {@code vector} in an array of {@code len} elements and
	 * starts the point in cluster 0.
	 *
	 * @param vector coordinates to copy; must not be longer than {@code len}
	 * @param len    number of dimensions of the point
	 */
	public DataObject(double[] vector,int len){
		final double[] copy=new double[len];
		System.arraycopy(vector, 0, copy, 0, vector.length);
		this.data=copy;
		this.len=len;
		this.index=0;
	}

	/** Returns the coordinate array held by this point (not a copy). */
	public double[] getVector(){
		return this.data;
	}

	/** Assigns this point to the cluster with the given id. */
	public void setCid(int index){
		this.index=index;
	}

	/** Returns the id of the cluster this point is assigned to. */
	public int getCid(){
		return this.index;
	}
}

Global类:
/**
 * Shared numeric helpers.
 */
class Global {
	/**
	 * Computes the Euclidean distance between two points, taking only their first
	 * {@code len} coordinates into account.
	 *
	 * @param vector first point
	 * @param center second point
	 * @param len    number of leading coordinates to compare
	 * @return the square root of the summed squared coordinate differences
	 */
	public static double calEuraDist(double[] vector,double[] center,int len){
		double squared=0;
		int d=0;
		while(d<len){
			squared+=Math.pow(vector[d]-center[d],2);
			d++;
		}
		return Math.sqrt(squared);
	}
}

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 4
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值