java实现k-means算法

 

 
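// Approach 1: each PPoint's flag field records which cluster center the point belongs to. Approach 2 (further below) keeps the centers and their member points in collections instead.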
import java.util.Random;
import java.util.Scanner;

/**
 * Mutable 2-D point with an integer tag used by the clustering code in this file.
 *
 * <p>The tag starts at -1. The clustering class in this file writes 0 into it for
 * cluster centers and {@code index + 1} for points assigned to the center at
 * {@code index}.
 */
class PPoint {
	/** Horizontal coordinate. */
	public float x;
	/** Vertical coordinate. */
	public float y;
	/** Cluster tag; -1 until something assigns it. */
	public int flag;

	/** Creates a point at the origin with the tag unset (-1). */
	public PPoint() {
		this(0f, 0f);
	}

	/**
	 * Creates a point at the given coordinates with the tag unset (-1).
	 *
	 * @param x horizontal coordinate
	 * @param y vertical coordinate
	 */
	public PPoint(float x, float y) {
		this.x = x;
		this.y = y;
		this.flag = -1;
	}
}


/**
 * Console demo of k-means clustering on random 2-D points.
 *
 * <p>{@link #init()} asks for the number of points and centers, draws distinct
 * random points on an integer grid and picks distinct points as the initial
 * centers. {@link #moveCore()} then alternates assignment and centroid
 * recomputation until no center moves by more than a small threshold.
 * Cluster membership is stored in each point's {@code flag} field:
 * {@code index + 1} of the owning center; centers themselves carry flag 0.
 */
public class Test {
	/** Random coordinates are integers in [0, GRID_SIZE) on both axes. */
	private static final int GRID_SIZE = 10;
	/** Iteration stops once every center moved by at most this distance. */
	private static final double MOVE_THRESHOLD = 0.01;

	/** All data points. */
	PPoint pc[] = null;
	/** Current cluster centers. */
	PPoint pcore[] = null;
	/** Centers computed by the latest call to {@link #calAverage()}. */
	PPoint pcoren[] = null;
	
	/**
	 * Reads the point count and center count from standard input, generates
	 * that many distinct random grid points, selects distinct points as the
	 * initial centers, and prints both sets.
	 *
	 * @throws IllegalArgumentException if the point count is not in
	 *         [1, GRID_SIZE * GRID_SIZE] (more distinct grid points than that
	 *         do not exist), or the center count is not in [1, point count]
	 */
	public void init(){
		// Not closed on purpose: closing the Scanner would also close System.in.
		Scanner sc = new Scanner(System.in);
		Random random = new Random();

		System.out.println("请输入生成随机点个数");
		int num = sc.nextInt();
		int maxPoints = GRID_SIZE * GRID_SIZE;
		if(num < 1 || num > maxPoints){
			// Without this check the duplicate-rejection loop below could never finish.
			throw new IllegalArgumentException(
					"point count must be between 1 and " + maxPoints + ", got " + num);
		}
		pc = new PPoint[num];
		// Rejection sampling: redraw whenever the candidate duplicates an existing point.
		int filled = 0;
		while(filled < num){
			float x = random.nextInt(GRID_SIZE);
			float y = random.nextInt(GRID_SIZE);
			if(!hasPoint(filled, x, y)){
				pc[filled] = new PPoint(x, y);
				filled++;
			}
		}

		System.out.println("请输入聚类中心个数");
		int core = sc.nextInt();
		if(core < 1 || core > num){
			// Centers are distinct data points, so there cannot be more centers than points.
			throw new IllegalArgumentException(
					"center count must be between 1 and " + num + ", got " + core);
		}
		pcore = new PPoint[core];
		pcoren = new PPoint[core];
		// Indices into pc already used as centers, to avoid duplicate centers.
		int chosen[] = new int[core];
		int picked = 0;
		while(picked < core){
			int candidate = random.nextInt(num);
			if(!hasIndex(chosen, picked, candidate)){
				chosen[picked] = candidate;
				pcore[picked] = new PPoint(pc[candidate].x, pc[candidate].y);
				pcore[picked].flag = 0;	// 0 marks a cluster center
				picked++;
			}
		}
		
		System.out.println("生成随机点如下:");
		for(int i=0;i<num;i++){
			System.out.println(pc[i].x+","+pc[i].y);
		}
		System.out.println("生成聚类中心如下");
		for(int i=0;i<pcore.length;i++){
			System.out.println("<"+pcore[i].x+","+pcore[i].y+">");
		}
		
	}

	/** Returns whether any of the first {@code count} entries of {@code pc} sits at (x, y). */
	private boolean hasPoint(int count, float x, float y){
		for(int j=0;j<count;j++){
			if(pc[j].x == x && pc[j].y == y){
				return true;
			}
		}
		return false;
	}

	/** Returns whether {@code candidate} occurs among the first {@code count} entries of {@code values}. */
	private static boolean hasIndex(int[] values, int count, int candidate){
		for(int j=0;j<count;j++){
			if(values[j] == candidate){
				return true;
			}
		}
		return false;
	}
	
	/**
	 * Runs the k-means iteration: assign points, recompute centers, and repeat
	 * until no center moved farther than {@code MOVE_THRESHOLD}. Prints
	 * a completion message when done.
	 *
	 * <p>Implemented as a loop rather than recursion so that the stack depth
	 * does not grow with the number of iterations.
	 */
	public void moveCore(){
		while(true){
			searchBelong();
			calAverage();
			boolean moved = false;
			for(int i=0;i<pcore.length;i++){
				if(distPPoint(pcore[i], pcoren[i]) > MOVE_THRESHOLD){
					moved = true;
					break;
				}
			}
			if(!moved){
				System.out.println("迭代完毕");
				return;
			}
			copyCore(pcore,pcoren);
		}
	}
	
	/**
	 * Copies the coordinates of {@code newCore} into the existing objects of
	 * {@code oldCore} and marks them as centers (flag 0).
	 *
	 * @param oldCore destination array; its elements are modified in place
	 * @param newCore source array; must be at least as long as {@code oldCore}
	 */
	public void copyCore(PPoint[] oldCore,PPoint[] newCore){
		// Bound by the destination array instead of the pcore field so the
		// method honours the arguments it is actually given.
		for(int i=0;i<oldCore.length;i++){
			oldCore[i].x = newCore[i].x;
			oldCore[i].y = newCore[i].y;
			oldCore[i].flag = 0;
		}
	}
	
	/**
	 * Assigns every point in {@code pc} to its nearest center in {@code pcore}
	 * by storing {@code centerIndex + 1} in the point's flag. On ties the
	 * center with the lowest index wins.
	 */
	public void searchBelong(){
		for(int i=0;i<pc.length;i++){
			// Start from infinity so that every finite distance is accepted,
			// however far the point is from the centers.
			double best = Double.POSITIVE_INFINITY;
			int label = -1;
			for(int j=0;j<pcore.length;j++){
				double distance = distPPoint(pc[i],pcore[j]);
				if(distance < best){
					best = distance;
					label = j;
				}
			}
			pc[i].flag = label + 1;
		}
	}
	
	/**
	 * Returns the Euclidean distance between two points.
	 *
	 * @param i first point
	 * @param j second point
	 * @return straight-line distance between {@code i} and {@code j}
	 */
	public double distPPoint(PPoint i,PPoint j){
		double dx = i.x - j.x;
		double dy = i.y - j.y;
		return Math.sqrt(dx * dx + dy * dy);
	}
	
	/**
	 * Prints the members of every cluster and stores each cluster's centroid
	 * in {@code pcoren}. A cluster without members keeps its current center,
	 * because dividing by a zero member count would yield NaN coordinates.
	 */
	public void calAverage(){
		for(int i=0;i<pcore.length;i++){
			System.out.println("属于<"+pcore[i].x+","+pcore[i].y+">的点有:");
			float sumX = 0;
			float sumY = 0;
			int number = 0;
			for(int j=0;j<pc.length;j++){
				if(pc[j].flag == (i+1)){
					System.out.println(pc[j].x+","+pc[j].y);
					sumX += pc[j].x;
					sumY += pc[j].y;
					number++;
				}
			}
			PPoint updated = new PPoint();
			if(number == 0){
				// Empty cluster: leave the center where it is.
				updated.x = pcore[i].x;
				updated.y = pcore[i].y;
			}else{
				updated.x = sumX / number;
				updated.y = sumY / number;
			}
			updated.flag = 0;
			pcoren[i] = updated;
			System.out.println("新的聚类中心为<"+pcoren[i].x+","+pcoren[i].y+">");
			
		}
	}
	
	/**
	 * Entry point: builds the random data set interactively, then clusters it.
	 *
	 * @param args unused
	 */
	public static void main(String[] args) {
		Test test = new Test();
		test.init();
		test.moveCore();
	}

}
// Approach 2: cluster centers and their member points are stored in collections, and the point set is read from a file.
package test;

import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Scanner;
import java.util.Set;

import org.apache.commons.io.IOUtils;

/**
 * K-means clustering over points read from a text file, using collections
 * (rather than a per-point flag) to record which points belong to which
 * center.
 *
 * <p>{@code Point} is declared outside this file; this class relies on its
 * no-argument and {@code (double, double)} constructors and on
 * {@code getX/getY/setX/setY}. Points are used as {@code HashMap} keys, so the
 * grouping depends on how {@code Point} implements {@code equals}/{@code hashCode}
 * (NOTE(review): not visible here, confirm).
 */
public class Test {

	/** Iteration stops once the largest center movement drops below this value. */
	private static final double maxDistance = 1.0e-9;
	/** Every point read from the input file. */
	private static List<Point> allPoint = new ArrayList<>();
	/** Current center -> points assigned to it during the running iteration. */
	private static Map<Point, List<Point>> map = new HashMap<>();
	/** New center (key) -> the center it replaces (value). */
	private static Map<Point,Point> replaceOldPoint = new HashMap<>();
	
	
	/**
	 * Reads the number of clusters from standard input, loads the points,
	 * picks random initial centers and iterates assignment / centroid update
	 * until the centers stop moving. Prints each iteration's result and the
	 * elapsed time in milliseconds.
	 *
	 * @param args unused
	 * @throws IOException if the point file cannot be read or is malformed
	 */
	public static void main(String[] args) throws IOException {
		// Not closed on purpose: closing the Scanner would also close System.in.
		Scanner sc = new Scanner(System.in);
		System.out.println("请输入聚类中心数目");
		int n = sc.nextInt();
		long startTime = System.currentTimeMillis();
		// Load every point from the file into allPoint.
		readAllPoint();
		// Choose the initial centers at random.
		randomClusterCenter(n);
		// K-means loop; leaves once the centers move less than the threshold.
		for(int i=0;;i++){
			// Assign each point to its nearest current center
			// (the current centers are the keys of replaceOldPoint).
			for(Point p : allPoint){
				Point center = judge(p);
				List<Point> members = map.get(center);
				if(members == null){
					members = new ArrayList<>();
					map.put(center, members);
				}
				members.add(p);
			}
			
			printResult(i+1);
			// Replace every center that has members by the centroid of those members.
			replaceOldPoint.clear();
			for(Map.Entry<Point, List<Point>> entry : map.entrySet()){
				double totalX = 0;
				double totalY = 0;
				List<Point> members = entry.getValue();
				for(Point member : members){
					totalX += member.getX();
					totalY += member.getY();
				}
				Point centroid = new Point(totalX / members.size(), totalY / members.size());
				replaceOldPoint.put(centroid, entry.getKey());
			}
			map.clear();
			if(maxMoveDistance() < maxDistance)
				break;
			
		}
		
		long endTime = System.currentTimeMillis();
		System.out.println("耗时:"+(endTime-startTime));
	}
	
	/**
	 * Reads {@code src/k-means_test.txt} and appends every point found to
	 * {@code allPoint}. The parser expects tokens of the form {@code (x,y)}
	 * separated by any amount of whitespace (spaces, tabs or line breaks).
	 *
	 * @throws IOException if the file cannot be read or a token does not
	 *         contain two comma-separated fields
	 * @throws NumberFormatException if a coordinate is not a valid number
	 */
	public static void readAllPoint() throws IOException{
		String content;
		// try-with-resources closes the reader even when reading fails.
		try(FileReader reader = new FileReader(new File("src/k-means_test.txt"))){
			// commons-io reads the whole file into one string.
			content = IOUtils.toString(reader);
		}
		// Split on whitespace runs so that line breaks and repeated spaces
		// between points are accepted as separators too.
		for(String token : content.split("\\s+")){
			// Strip the surrounding parentheses.
			String cleaned = token.replaceAll("[\\(\\)]+", "");
			if(cleaned.isEmpty()){
				continue;	// leading/trailing whitespace yields an empty token
			}
			String[] fields = cleaned.split(",");
			if(fields.length < 2){
				throw new IOException("Malformed point in input file: " + token);
			}
			Point point = new Point();
			point.setX(Double.parseDouble(fields[0]));
			point.setY(Double.parseDouble(fields[1]));
			allPoint.add(point);
		}
	}

	/**
	 * Picks {@code n} distinct entries of {@code allPoint} at random and
	 * registers them as the initial centers. Each one is paired with a
	 * placeholder "old" center at ({@code Double.MAX_VALUE}, {@code Double.MAX_VALUE}).
	 *
	 * @param n number of centers to choose
	 * @throws IllegalArgumentException if {@code n} is not in
	 *         [1, number of loaded points]; more distinct indices than points
	 *         cannot be drawn
	 */
	public static void randomClusterCenter(int n){
		if(n < 1 || n > allPoint.size()){
			throw new IllegalArgumentException(
					"cluster count must be between 1 and " + allPoint.size() + ", got " + n);
		}
		Random generator = new Random();
		List<Integer> list = new ArrayList<>();
		while(list.size() < n){
			int random;
			// Redraw until the index has not been used yet, so centers are distinct indices.
			do{
				random = generator.nextInt(allPoint.size());
			}while(list.contains(random));
			list.add(random);
		}
		for(int number : list){
			Point old = new Point(Double.MAX_VALUE,Double.MAX_VALUE);
			replaceOldPoint.put(allPoint.get(number), old);
		}
	}
	
	/**
	 * Finds the current center (a key of {@code replaceOldPoint}) closest to
	 * the given point.
	 *
	 * @param p point to classify
	 * @return the nearest center, or {@code null} if there are no centers
	 */
	public static Point judge(Point p){
		double dist = Double.MAX_VALUE;
		Point flagPoint = null;
		for(Point pp : replaceOldPoint.keySet()){
			double distance = getDistance(p,pp);
			if(distance < dist){
				dist = distance;
				flagPoint = pp;
			}
		}
		return flagPoint;
	}
	
	/**
	 * Returns the Euclidean distance between two points.
	 *
	 * @param p first point
	 * @param pp second point
	 * @return straight-line distance between {@code p} and {@code pp}
	 */
	public static double getDistance(Point p,Point pp){
		double dx = p.getX() - pp.getX();
		double dy = p.getY() - pp.getY();
		return Math.sqrt(dx * dx + dy * dy);
	}
	
	/**
	 * Prints the current centers and, for each of them, the points assigned
	 * to it in {@code map}.
	 *
	 * @param i 1-based iteration number shown in the heading
	 */
	public static void printResult(int i){
		System.out.println("第"+i+"次聚类结果");
		String s = "聚类中心为:"+replaceOldPoint.keySet();
		System.out.println(s);
		for(Point p : replaceOldPoint.keySet()){
			System.out.println("属于中心<"+p.getX()+","+p.getY()+">的点:"+map.get(p));
			
		}
	}
	
	/**
	 * Returns the largest distance between a new center and the center it
	 * replaced, or 0 when there are no centers.
	 *
	 * @return maximum center movement of the latest iteration
	 */
	public static double maxMoveDistance(){
		// Distances are never negative, so 0 is the correct starting maximum
		// (Double.MIN_VALUE is the smallest positive double, not a lower bound).
		double flagDistance = 0.0;
		for(Map.Entry<Point, Point> entry : replaceOldPoint.entrySet()){
			double moved = getDistance(entry.getKey(), entry.getValue());
			if(moved > flagDistance){
				flagDistance = moved;
			}
		}
		return flagDistance;
	}
}
点集文件: 点击打开链接



  • 1
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值