// kmeans算法java实现 (K-means clustering algorithm, Java implementation)

package zqr.com;

import breeze.collection.mutable.ArrayLike;

import java.io.*;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;


public class KmeansFect {

    /** Input file read by {@link #ReadFile()} and by {@link #main} when no path argument is given. */
    private static final String DEFAULT_PATH = "/usr/local/spark/data/mllib/kmeans_data.txt";

    /** Number of coordinates per point; every input line must contain exactly this many numbers. */
    private static final int DIMENSIONS = 3;

    /** Number of assign/update passes performed by {@link #main}. */
    private static final int MAX_ITERATIONS = 20;

    /** Passes remaining; counted down by {@link #main} and inspected by {@link #IsOverValue()}. */
    private static int IntorNUm = MAX_ITERATIONS;

    /**
     * The two current cluster centers (index 0 and 1). Entries are arrays owned by this class,
     * never aliases of input points, so updating a center cannot modify the data set.
     */
    private static final List<Double[]> TowCenter = new ArrayList<>();

    /** Points assigned to center 0 during the current pass. */
    private static List<Double[]> oneC = new ArrayList<>();

    /** Points assigned to center 1 during the current pass. */
    private static List<Double[]> towC = new ArrayList<>();

    /**
     * Runs 2-means clustering on the points of a whitespace-separated text file, printing the
     * centers on every pass and the final centers at the end.
     *
     * @param args optional; {@code args[0]} overrides the default input path
     */
    public static void main(String[] args) {
        String pathname = (args != null && args.length > 0) ? args[0] : DEFAULT_PATH;
        List<Double[]> data = ReadFile(pathname);
        if (data.isEmpty()) {
            System.err.println("No data points read from " + pathname + "; nothing to cluster.");
            return;
        }

        // Choose the initial centers exactly once: choosing new random centers on each pass
        // would throw away the refinement made by ReCuptCenter.
        ChooseCenter(data);

        while (IntorNUm > 0) {
            System.out.println("============================================================"
                    + (MAX_ITERATIONS - IntorNUm));
            DistanceCput(data, TowCenter); // assignment step
            ReCuptCenter(oneC, towC);      // update step
            IsOverValue();                 // resets oneC/towC before the next pass
            IntorNUm--;
        }

        System.out.println("最终聚类中心1->" + formatPoint(TowCenter.get(0)));
        System.out.println("最终聚类中心2->" + formatPoint(TowCenter.get(1)));
    }

    /**
     * Reads the points stored at the default input path.
     *
     * @return a new list of parsed points (empty if the file cannot be read)
     */
    public static List<Double[]> ReadFile() {
        return ReadFile(DEFAULT_PATH);
    }

    /**
     * Reads one point per line from a text file. Each non-blank line must hold exactly
     * {@value #DIMENSIONS} numbers separated by whitespace; other lines are reported on stderr and
     * skipped. Every accepted line is echoed to stdout.
     *
     * @param pathname absolute or relative path of the input file (decoded as UTF-8)
     * @return a new list with the points read so far; empty if the file cannot be opened
     */
    public static List<Double[]> ReadFile(String pathname) {
        List<Double[]> points = new ArrayList<>();
        try (BufferedReader br = new BufferedReader(
                new InputStreamReader(new FileInputStream(pathname), StandardCharsets.UTF_8))) {
            String line;
            while ((line = br.readLine()) != null) {
                String trimmed = line.trim();
                if (trimmed.isEmpty()) {
                    continue; // blank lines (e.g. a trailing newline) carry no point
                }
                Double[] point = parsePoint(trimmed);
                if (point == null) {
                    System.err.println("Skipping malformed line: " + line);
                    continue;
                }
                System.out.println(line);
                points.add(point);
            }
        } catch (IOException e) {
            // Best effort, as before: report the problem and return whatever was parsed.
            System.err.println("Failed to read " + pathname + ": " + e);
        }
        return points;
    }

    /**
     * Parses one trimmed, non-empty line into a point.
     *
     * @return the point, or {@code null} if the line does not hold exactly {@value #DIMENSIONS}
     *     numbers
     */
    private static Double[] parsePoint(String trimmedLine) {
        String[] tokens = trimmedLine.split("\\s+");
        if (tokens.length != DIMENSIONS) {
            return null;
        }
        Double[] point = new Double[DIMENSIONS];
        try {
            for (int i = 0; i < DIMENSIONS; i++) {
                point[i] = Double.parseDouble(tokens[i]);
            }
        } catch (NumberFormatException e) {
            return null;
        }
        return point;
    }

    /**
     * Picks two initial cluster centers at random from the data set. When the data has at least
     * two points, the two picks are distinct indices. The chosen points are copied, so later center
     * updates never modify {@code data}. Any previously chosen centers are discarded.
     *
     * @param data the points to choose from
     * @return the live list of the two current centers
     * @throws IllegalArgumentException if {@code data} is empty
     */
    public static List<Double[]> ChooseCenter(List<Double[]> data) {
        int length = data.size();
        System.out.println("获取长度为:" + length);
        if (length == 0) {
            throw new IllegalArgumentException("Cannot choose cluster centers from an empty data set");
        }

        Random random = new Random();
        int suiji1 = random.nextInt(length);
        System.out.println("获取随机数1为:" + suiji1);

        int suiji2 = suiji1;
        if (length > 1) {
            // Draw from the remaining length-1 indices, then skip over suiji1.
            suiji2 = random.nextInt(length - 1);
            if (suiji2 >= suiji1) {
                suiji2++;
            }
        }
        System.out.println("获取随机数2为:" + suiji2);

        TowCenter.clear();
        TowCenter.add(data.get(suiji1).clone());
        TowCenter.add(data.get(suiji2).clone());
        return TowCenter;
    }

    /**
     * Assignment step: appends each point to the cluster of its nearer center. Points at least as
     * close to center 0 go to {@code oneC}; the rest go to {@code towC}. Each point is assigned
     * exactly once. Points that are {@code null} or lack {@value #DIMENSIONS} non-null coordinates
     * are skipped. The cluster lists are appended to, not cleared.
     *
     * @param data the points to assign
     * @param towCenter the two centers (index 0 and 1)
     */
    public static void DistanceCput(List<Double[]> data, List<Double[]> towCenter) {
        System.out.println("聚类中心1->" + formatPoint(towCenter.get(0)));
        System.out.println("聚类中心2->" + formatPoint(towCenter.get(1)));
        for (Double[] x : data) {
            if (!isUsablePoint(x)) {
                continue;
            }
            if (nearestCenter(x, towCenter) == 0) {
                oneC.add(x);
            } else {
                towC.add(x);
            }
        }
    }

    /**
     * Returns the index (0 or 1) of the center nearer to {@code point}; ties resolve to 0.
     * Package-private so it can be tested directly.
     */
    static int nearestCenter(Double[] point, List<Double[]> centers) {
        double toFirst = squaredDistance(point, centers.get(0));
        double toSecond = squaredDistance(point, centers.get(1));
        return toFirst <= toSecond ? 0 : 1;
    }

    /**
     * Squared Euclidean distance over the first {@value #DIMENSIONS} coordinates. The square root
     * is omitted because only comparisons are needed. Both arrays must have at least
     * {@value #DIMENSIONS} non-null entries.
     */
    static double squaredDistance(Double[] a, Double[] b) {
        double sum = 0.0;
        for (int i = 0; i < DIMENSIONS; i++) {
            double d = a[i] - b[i];
            sum += d * d;
        }
        return sum;
    }

    /** True if {@code point} has at least {@value #DIMENSIONS} leading non-null coordinates. */
    private static boolean isUsablePoint(Double[] point) {
        if (point == null || point.length < DIMENSIONS) {
            return false;
        }
        for (int i = 0; i < DIMENSIONS; i++) {
            if (point[i] == null) {
                return false;
            }
        }
        return true;
    }

    /**
     * Update step: moves each center to the coordinate-wise mean (x1+...+xn)/n of its cluster.
     * A center whose cluster is empty keeps its previous position, since the mean of zero points
     * is undefined (0/0 would give NaN).
     *
     * @param somepoint1 points assigned to center 0
     * @param somepoint2 points assigned to center 1
     */
    public static void ReCuptCenter(List<Double[]> somepoint1, List<Double[]> somepoint2) {
        // A threshold on how far the centers moved could be added here to stop early.
        moveCenterToMean(0, somepoint1);
        moveCenterToMean(1, somepoint2);
    }

    /** Replaces center {@code index} with the mean of {@code cluster}; does nothing if it is empty. */
    private static void moveCenterToMean(int index, List<Double[]> cluster) {
        if (cluster.isEmpty()) {
            return;
        }
        double[] sums = new double[DIMENSIONS];
        for (Double[] p : cluster) {
            for (int i = 0; i < DIMENSIONS; i++) {
                sums[i] += p[i];
            }
        }
        Double[] mean = new Double[DIMENSIONS];
        for (int i = 0; i < DIMENSIONS; i++) {
            mean[i] = sums[i] / cluster.size();
        }
        // Store a fresh array so the previous center object (and any data point) stays untouched.
        TowCenter.set(index, mean);
    }

    /**
     * Reports whether the iteration budget is exhausted. If it is not, also resets both cluster
     * lists so the next assignment pass starts empty.
     *
     * @return {@code true} when no iterations remain ({@code IntorNUm == 0})
     */
    public static Boolean IsOverValue() {
        if (IntorNUm == 0) {
            return true;
        }
        oneC = new ArrayList<>();
        towC = new ArrayList<>();
        return false;
    }

    /** Formats a point as its coordinates separated by single spaces. */
    private static String formatPoint(Double[] point) {
        StringBuilder sb = new StringBuilder();
        for (int i = 0; i < point.length; i++) {
            if (i > 0) {
                sb.append(' ');
            }
            sb.append(point[i]);
        }
        return sb.toString();
    }
}
// 阅读更多 (Read more)
// 换一批 (Show another batch)
//
// 没有更多推荐了,返回首页 (No more recommendations; return to the home page)