Kmeans算法java代码

package com.nju.yzf;

import java.io.BufferedReader;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;

public class Kmeans {
	
	/**
	 * @param args
	 * @throws IOException
	 */
	
	public static List<ArrayList<ArrayList<Double>>> 
	initHelpCenterList(List<ArrayList<ArrayList<Double>>> helpCenterList,int k){
		for(int i=0;i<k;i++){
			helpCenterList.add(new ArrayList<ArrayList<Double>>());
		}	
		return helpCenterList;
	}
	
	/**
	 * @param args
	 * @throws IOException
	 */
	public static void main(String[] args) throws IOException{
		
		List<ArrayList<Double>> centers = new ArrayList<ArrayList<Double>>();
		List<ArrayList<Double>> newCenters = new ArrayList<ArrayList<Double>>();
		List<ArrayList<ArrayList<Double>>> helpCenterList = new ArrayList<ArrayList<ArrayList<Double>>>();
		
		//读入原始数据
		BufferedReader br=new BufferedReader(new InputStreamReader(new FileInputStream("wine.txt")));
		String data = null;
		List<ArrayList<Double>> dataList = new ArrayList<ArrayList<Double>>();
		while((data=br.readLine())!=null){
			//System.out.println(data);
			String []fields = data.split(",");
			List<Double> tmpList = new ArrayList<Double>();
			for(int i=0; i<fields.length;i++)
				tmpList.add(Double.parseDouble(fields[i]));
			dataList.add((ArrayList<Double>) tmpList);
		}
		br.close();
		
		//随机确定K个初始聚类中心
		Random rd = new Random();
		int k=3;
		int [] initIndex={59,71,48};
		int [] helpIndex = {0,59,130};
		int [] givenIndex = {0,1,2};
		System.out.println("random centers' index");
		for(int i=0;i<k;i++){
			int index = rd.nextInt(initIndex[i]) + helpIndex[i];
			//int index = givenIndex[i];
			System.out.println("index "+index);
			centers.add(dataList.get(index));
			helpCenterList.add(new ArrayList<ArrayList<Double>>());
		}	
		
		/*
		//注释掉的这部分目的是,取测试数据集最后稳定的三个类簇的聚类中心作为初始聚类中心
		centers = new ArrayList<ArrayList<Double>>();
		for(int i=0;i<59;i++)
			helpCenterList.get(0).add(dataList.get(i));
		for(int i=59;i<130;i++)
			helpCenterList.get(1).add(dataList.get(i));
		for(int i=130;i<178;i++)
			helpCenterList.get(2).add(dataList.get(i));
		for(int i=0;i<k;i++){
			
			ArrayList<Double> tmp = new ArrayList<Double>();
			
			for(int j=0;j<dataList.get(0).size();j++){
				double sum=0;
				for(int t=0;t<helpCenterList.get(i).size();t++)
					sum+=helpCenterList.get(i).get(t).get(j);
				tmp.add(sum/helpCenterList.get(i).size());
			}
			centers.add(tmp);
		}
		*/
		
		//输出k个初始中心
		System.out.println("original centers:");
		for(int i=0;i<k;i++)
			System.out.println(centers.get(i));
		
		while(true)
		{//进行若干次迭代,直到聚类中心稳定
			
			for(int i=0;i<dataList.size();i++){//标注每一条记录所属于的中心
				double minDistance=99999999;
				int centerIndex=-1;
				for(int j=0;j<k;j++){//离0~k之间哪个中心最近
					double currentDistance=0;
					for(int t=1;t<centers.get(0).size();t++){//计算两点之间的欧式距离
						currentDistance	+=	((centers.get(j).get(t)-dataList.get(i).get(t))/(centers.get(j).get(t)+dataList.get(i).get(t))) * ((centers.get(j).get(t)-dataList.get(i).get(t))/(centers.get(j).get(t)+dataList.get(i).get(t))); 
					}
					if(minDistance>currentDistance){
						minDistance=currentDistance;
						centerIndex=j;
					}
				}
				helpCenterList.get(centerIndex).add(dataList.get(i));
			}
			
		//	System.out.println(helpCenterList);
			
			//计算新的k个聚类中心
			for(int i=0;i<k;i++){
				
				ArrayList<Double> tmp = new ArrayList<Double>();
				
				for(int j=0;j<centers.get(0).size();j++){
					double sum=0;
					for(int t=0;t<helpCenterList.get(i).size();t++)
						sum+=helpCenterList.get(i).get(t).get(j);
					tmp.add(sum/helpCenterList.get(i).size());
				}
				
				newCenters.add(tmp);
				
			}
			System.out.println("\nnew clusters' centers:\n");
			for(int i=0;i<k;i++)
				System.out.println(newCenters.get(i));
			//计算新旧中心之间的距离,当距离小于阈值时,聚类算法结束
			double distance=0;
			
			for(int i=0;i<k;i++){
				for(int j=1;j<centers.get(0).size();j++){//计算两点之间的欧式距离
					distance +=	((centers.get(i).get(j)-newCenters.get(i).get(j))/(centers.get(i).get(j)+newCenters.get(i).get(j))) * ((centers.get(i).get(j)-newC
  • 3
    点赞
  • 18
    收藏
    觉得还不错? 一键收藏
  • 20
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 20
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值