K-Means算法的代码实现(Java)

 

K-Means算法的代码实现(Java)  

2013-03-02 20:28:41|  分类: 数据挖掘 |  标签: |举报 |字号大中小 订阅

//package cn.edu.pku.ss.dm.cluster;

import java.io.BufferedReader;

import java.io.BufferedWriter;

import java.io.FileNotFoundException;

import java.io.FileReader;

import java.io.FileWriter;

import java.io.IOException;

import java.util.ArrayList;

 

//K-means算法实现

 

public class KMeans {

    //聚类的数目

    final static int ClassCount = 3;

    //样本数目(测试集)

    final static int InstanceNumber = 150; 

    //样本属性数目(测试)

    final static int FieldCount = 5;

   

    //设置异常点阈值参数(每一类初始的最小数目为InstanceNumber/ClassCount^t

    final static double t = 2.0;

    //存放数据的矩阵

    private float[][] data;

   

    //每个类的均值中心

    private float[][] classData;

   

    //噪声集合索引

    private ArrayList<Integer> noises;

   

    //存放每次变换结果的矩阵

    private ArrayList<ArrayList<Integer>> result;

   

    //构造函数,初始化

    public KMeans()

    {

   //最后一位用来储存结果

   data = new float[InstanceNumber][FieldCount+1];

   classData = new float[ClassCount][FieldCount];

   result = new ArrayList<ArrayList<Integer>>(ClassCount);

   noises = new ArrayList<Integer>();

  

    }

   

 

   /**

    * 主函数入口

    * 测试集的文件名称为测试集.data”,其中有1000*57大小的数据

    * 每一行为一个样本,有57个属性

    * 主要分为两个步骤

    * 1.读取数据

    * 2.进行聚类

    * 最后统计运行时间和消耗的内存

    * @param args

    */

   public static void main(String[] args) {

      // TODO Auto-generated method stub

       long startTime = System.currentTimeMillis();

       KMeans cluster = new KMeans();

       //读取数据

       cluster.readData("D:/test.txt");

       //聚类过程

       cluster.cluster();

       //输出结果

       cluster.printResult("clusterResult.data");

       long endTime = System.currentTimeMillis();

       System.out.println("Total Time:"+ (endTime - startTime)/1000+"s");

       System.out.println("Memory Consuming:"+(float)(Runtime.getRuntime().totalMemory() -

          Runtime.getRuntime().freeMemory())/1000000 + "MB");

   }

        /*

         * 读取测试集的数据

         *

         * @param trainingFileName 测试集文件名

         */

   public void readData(String trainingFileName)

   {

       try

       {

      FileReader fr = new FileReader(trainingFileName);

      BufferedReader br = new BufferedReader(fr);

      //存放数据的临时变量

      String lineData = null;

      String[] splitData = null;

      int line = 0;

      //按行读取

      while(br.ready())

      {

          //得到原始的字符串

          lineData = br.readLine();

          splitData = lineData.split(",");

          //转化为数据

//        System.out.println("length:"+splitData.length);

          if(splitData.length>1)

          {

             for(int i = 0;i < splitData.length;i++)

             {

//              System.out.println(splitData[i]);

//              System.out.println(splitData[i].getClass());

                if(splitData[i].startsWith("Iris-setosa"))

                {

                   data[line][i] = (float) 1.0;

                }

                else if(splitData[i].startsWith("Iris-versicolor"))

                {

                   data[line][i] = (float) 2.0;

                }

                else if(splitData[i].startsWith("Iris-virginica"))

                {

                   data[line][i] = (float) 3.0;

                }

                else

                {   //将数据截取之后放进数组

                   data[line][i] = Float.parseFloat(splitData[i]);

                }

             }

             line++;

          }

      }

      System.out.println(line);

       }catch(IOException e)

       {

      e.printStackTrace();

       }

   }

   /*

    * 聚类过程,主要分为两步

    * 1.循环找初始点

    * 2.不断调整直到分类不再发生变化

    */

   public void cluster()

   {

       //数据归一化

       normalize();

       //标记是否需要重新找初始点

       boolean needUpdataInitials = true;

      

       //找初始点的迭代次数

       int times = 1;

       //找初始点

       while(needUpdataInitials)

       {

      needUpdataInitials = false;

      result.clear();

      System.out.println("Find Initials Iteration"+(times++)+"time(s)");

     

      //一次找初始点的尝试和根据初始点的分类

      findInitials();

      firstClassify();

     

      //如果某个分类的数目小于特定的阈值,则认为这个分类中的所有样本都是噪声点

      //需要重新找初始点

      for(int i = 0;i < result.size();i++)

      {

          if(result.get(i).size() < InstanceNumber/Math.pow(ClassCount,t))

          {

         needUpdataInitials = true;

         noises.addAll(result.get(i));

          }

      }

       }

      

       //找到合适的初始点后

       //不断的调整均值中心和分类,直到不再发生任何变化

       Adjust();

   }

  

   /*

    * 对数据进行归一化

    * 1.找每一个属性的最大值

    * 2.对某个样本的每个属性除以其最大值

    */

   public void normalize()

   {

       //找最大值

       float[] max = new float[FieldCount];

       for(int i = 0;i < InstanceNumber;i++)

       {

      for(int j = 0;j < FieldCount;j++)

      {

          if(data[i][j] > max[j])

         max[j] = data[i][j];

      }

       }

      

       //归一化

       for(int i = 0;i < InstanceNumber;i++)

       {

      for(int j = 0;j < FieldCount;j++)

      {

          data[i][j] = data[i][j]/max[j];

      }

       }

   }

  

   //关于初始向量的一次找寻尝试

   public void findInitials()

   {

       //a,b为标志距离最远的两个向量的索引

       int i,j,a,b;

       i = j = a = b = 0;

      

       //最远距离

       float maxDis = 0;

      

       //已经找到的初始点个数

       int alreadyCls = 2;

      

       //存放已经标记为初始点的向量索引

       ArrayList<Integer> initials = new ArrayList<Integer>();

      

       //从两个开始

       for(;i < InstanceNumber;i++)

       {

      //噪声点

      if(noises.contains(i))

          continue;

      //long startTime = System.currentTimeMillis();

      j = i + 1;

      for(;j < InstanceNumber;j++)

      {

          //噪声点

          if(noises.contains(j))

         continue;

          //找出最大的距离并记录下来

          float newDis = calDis(data[i],data[j]);

          if(maxDis < newDis)

          {

         a = i;

         b = j;

         maxDis = newDis;

          }

      }

      //long endTime = System.currentTimeMillis();

      //System.out.println(i + "Vector Caculation Time:"+(endTime-startTime)+"ms");

       }

      

       //将前两个初始点记录下来

       initials.add(a);

       initials.add(b);

       classData[0] = data[a];

       classData[1] = data[b];

      

       //在结果中新建存放某样本索引的对象,并把初始点添加进去

       ArrayList<Integer> resultOne = new ArrayList<Integer>();

       ArrayList<Integer> resultTwo = new ArrayList<Integer>();

       resultOne.add(a);

       resultTwo.add(b);

       result.add(resultOne);

       result.add(resultTwo);

      

       //找到剩余的几个初始点

       while(alreadyCls < ClassCount)

       {

      i = j = 0;

      float maxMin = 0;

      int newClass = -1;

     

      //找最小值中的最大值

      for(;i < InstanceNumber;i++)

      {

          float min = 0;

          float newMin = 0;

          //找和已有类的最小值

          if(initials.contains(i))

         continue;

          //噪声点去除

          if(noises.contains(i))

         continue;

          for(j = 0;j < alreadyCls;j++)

          {

         newMin = calDis(data[i],classData[j]);

         if(min == 0 || newMin < min)

             min = newMin;

          }

         

          //新最小距离较大

          if(min > maxMin)

          {

         maxMin = min;

         newClass = i;

          }

      }

      //添加到均值集合和结果集合中

      //System.out.println("NewClass"+newClass);

      initials.add(newClass);

      classData[alreadyCls++] = data[newClass];

      ArrayList<Integer> rslt = new ArrayList<Integer>();

      rslt.add(newClass);

      result.add(rslt);

       }

   }

  

   //第一次分类

   public void firstClassify()

   {

       //根据初始向量分类

       for(int i = 0;i < InstanceNumber;i++)

       {

      float min = 0f;

      int clsId = -1;

      for(int j = 0;j < classData.length;j++)

      {

          //欧式距离

          float newMin = calDis(classData[j],data[i]);

          if(clsId == -1 || newMin <min)

          {

         clsId = j;

         min = newMin;

          }

         

      }

      //本身不再添加

      if(!result.get(clsId).contains(i))

          result.get(clsId).add(i);

       }

   }

   //迭代分类,直到各个类的数据不再变化

   public void Adjust()

   {

       //记录是否发生变化

       boolean change = true;

      

       //循环的次数

       int times = 1;

       while(change)

       {

      //复位

      change = false;

      System.out.println("Adjust Iteration"+(times++)+"time(s)");

                   

      //重新计算每个类的均值 

      for(int i = 0;i < ClassCount; i++){ 

      //原有的数据 

      ArrayList<Integer> cls = result.get(i); 

       

      //新的均值 

      float[] newMean = new float[FieldCount ]; 

       

      //计算均值 

      for(Integer index:cls){ 

       for(int j = 0;j < FieldCount ;j++) 

              newMean[j] += data[index][j]; 

       } 

      for(int j = 0;j < FieldCount ;j++) 

         newMean[j] /= cls.size(); 

      if(!compareMean(newMean, classData[i])){ 

         classData[i] = newMean; 

           change = true

           } 

      } 

      //清空之前的数据 

      for(ArrayList<Integer> cls:result

       cls.clear(); 

        

      //重新分配 

      for(int i = 0;i < InstanceNumber;i++) 

      { 

       float min = 0f; 

       int clsId = -1; 

       for(int j = 0;j < classData.length;j++){ 

        float newMin = calDis(classData[j], data[i]); 

       if(clsId == -1 || newMin < min){ 

         clsId = j; 

           min = newMin; 

               } 

                 } 

                   data[i][FieldCount] = clsId; 

                    result.get(clsId).add(i); 

              } 

                 

         //测试聚类效果(训练集

      //          for(int i = 0;i < ClassCount;i++){ 

      //              int positives = 0; 

      //              int negatives = 0; 

      //              ArrayList<Integer> cls = result.get(i); 

      //              for(Integer instance:cls) 

      //                  if (data[instance][FieldCount - 1] == 1f) 

      //                      positives ++; 

      //                  else 

      //                      negatives ++; 

      //              System.out.println(" " + i + " Positive: " + positives + " Negatives: " + negatives); 

      //          } 

      //          System.out.println(); 

       }

               

               

   } 

          

         /**

           * 计算a样本和b样本的欧式距离作为不相似度

           * 

           * @param a     样本a

           * @param b     样本b

           * @return      欧式距离长度

           */ 

   private float calDis(float[] aVector,float[] bVector)  {

      double dis = 0;

      int i = 0;

               /*最后一个数据在训练集中为结果,所以不考虑  */

                for(;i < aVector.length;i++)

                     dis += Math.pow(bVector[i] - aVector[i],2); 

                dis = Math.pow(dis, 0.5); 

                return (float)dis; 

   }

         

        /**

         * 判断两个均值向量是否相等

         * 

         * @param a 向量a

              * @param b 向量b

         * @return

         */ 

       private boolean compareMean(float[] a,float[] b) 

       { 

             if(a.length != b.length

               return false

             for(int i =0;i < a.length;i++){ 

             if(a[i] > 0 &&b[i] > 0&& a[i] != b[i]){ 

                  return false

                }    

            } 

              return true

        } 

          

        /**

         * 将结果输出到一个文件中

         * 

              * @param fileName

              */ 

         public void printResult(String fileName) 

       { 

       FileWriter fw = null

            BufferedWriter bw = null

            try

                  fw = new FileWriter(fileName); 

               bw = new BufferedWriter(fw); 

              //写入文件 

               for(int i = 0;i < InstanceNumber;i++) 

               { 

                  bw.write(String.valueOf(data[i][FieldCount]).substring(0, 1)); 

                   bw.newLine(); 

                } 

               

               //统计每类的数目,打印到控制台 

               for(int i = 0;i < ClassCount;i++) 

             { 

                     System.out.println("" + (i+1) + "类数目: " + result.get(i).size()); 

              } 

         } catch (IOException e) { 

             e.printStackTrace(); 

           } finally

                 

               //关闭资源 

             if(bw != null

                   try

                     bw.close(); 

                   } catch (IOException e) { 

                       e.printStackTrace(); 

                  } 

               if(fw != null

                    try

                        fw.close(); 

                   } catch (IOException e) { 

                         e.printStackTrace(); 

                    } 

             } 

             

        } 

      } 


  • 0
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值