算法——K均值聚类算法(Java实现)

7 篇文章 0 订阅

实现:Java代码如下

[java]  view plain copy
  1. package org.algorithm;  
  2.   
  3. import java.util.ArrayList;  
  4. import java.util.Random;  
  5.   
  6. /** 
  7.  * K均值聚类算法 
  8.  */  
  9. public class Kmeans {  
  10.     private int k;// 分成多少簇  
  11.     private int m;// 迭代次数  
  12.     private int dataSetLength;// 数据集元素个数,即数据集的长度  
  13.     private ArrayList<float[]> dataSet;// 数据集链表  
  14.     private ArrayList<float[]> center;// 中心链表  
  15.     private ArrayList<ArrayList<float[]>> cluster; // 簇  
  16.     private ArrayList<Float> jc;// 误差平方和,k越接近dataSetLength,误差越小  
  17.     private Random random;  
  18.   
  19.     /** 
  20.      * 设置需分组的原始数据集 
  21.      *  
  22.      * @param dataSet 
  23.      */  
  24.   
  25.     public void setDataSet(ArrayList<float[]> dataSet) {  
  26.         this.dataSet = dataSet;  
  27.     }  
  28.   
  29.     /** 
  30.      * 获取结果分组 
  31.      *  
  32.      * @return 结果集 
  33.      */  
  34.   
  35.     public ArrayList<ArrayList<float[]>> getCluster() {  
  36.         return cluster;  
  37.     }  
  38.   
  39.     /** 
  40.      * 构造函数,传入需要分成的簇数量 
  41.      *  
  42.      * @param k 
  43.      *            簇数量,若k<=0时,设置为1,若k大于数据源的长度时,置为数据源的长度 
  44.      */  
  45.     public Kmeans(int k) {  
  46.         if (k <= 0) {  
  47.             k = 1;  
  48.         }  
  49.         this.k = k;  
  50.     }  
  51.   
  52.     /** 
  53.      * 初始化 
  54.      */  
  55.     private void init() {  
  56.         m = 0;  
  57.         random = new Random();  
  58.         if (dataSet == null || dataSet.size() == 0) {  
  59.             initDataSet();  
  60.         }  
  61.         dataSetLength = dataSet.size();  
  62.         if (k > dataSetLength) {  
  63.             k = dataSetLength;  
  64.         }  
  65.         center = initCenters();  
  66.         cluster = initCluster();  
  67.         jc = new ArrayList<Float>();  
  68.     }  
  69.   
  70.     /** 
  71.      * 如果调用者未初始化数据集,则采用内部测试数据集 
  72.      */  
  73.     private void initDataSet() {  
  74.         dataSet = new ArrayList<float[]>();  
  75.         // 其中{6,3}是一样的,所以长度为15的数据集分成14簇和15簇的误差都为0  
  76.         float[][] dataSetArray = new float[][] { { 82 }, { 34 }, { 25 },  
  77.                 { 42 }, { 73 }, { 62 }, { 47 }, { 63 }, { 53 },  
  78.                 { 63 }, { 69 }, { 16 }, { 39 }, { 41 }, { 86 } };  
  79.   
  80.         for (int i = 0; i < dataSetArray.length; i++) {  
  81.             dataSet.add(dataSetArray[i]);  
  82.         }  
  83.     }  
  84.   
  85.     /** 
  86.      * 初始化中心数据链表,分成多少簇就有多少个中心点 
  87.      *  
  88.      * @return 中心点集 
  89.      */  
  90.     private ArrayList<float[]> initCenters() {  
  91.         ArrayList<float[]> center = new ArrayList<float[]>();  
  92.         int[] randoms = new int[k];  
  93.         boolean flag;  
  94.         int temp = random.nextInt(dataSetLength);  
  95.         randoms[0] = temp;  
  96.         for (int i = 1; i < k; i++) {  
  97.             flag = true;  
  98.             while (flag) {  
  99.                 temp = random.nextInt(dataSetLength);  
  100.                 int j = 0;  
  101.                 // 不清楚for循环导致j无法加1  
  102.                 // for(j=0;j<i;++j)  
  103.                 // {  
  104.                 // if(temp==randoms[j]);  
  105.                 // {  
  106.                 // break;  
  107.                 // }  
  108.                 // }  
  109.                 while (j < i) {  
  110.                     if (temp == randoms[j]) {  
  111.                         break;  
  112.                     }  
  113.                     j++;  
  114.                 }  
  115.                 if (j == i) {  
  116.                     flag = false;  
  117.                 }  
  118.             }  
  119.             randoms[i] = temp;  
  120.         }  
  121.   
  122.         // 测试随机数生成情况  
  123.         // for(int i=0;i<k;i++)  
  124.         // {  
  125.         // System.out.println("test1:randoms["+i+"]="+randoms[i]);  
  126.         // }  
  127.   
  128.         // System.out.println();  
  129.         for (int i = 0; i < k; i++) {  
  130.             center.add(dataSet.get(randoms[i]));// 生成初始化中心链表  
  131.         }  
  132.         return center;  
  133.     }  
  134.   
  135.     /** 
  136.      * 初始化簇集合 
  137.      *  
  138.      * @return 一个分为k簇的空数据的簇集合 
  139.      */  
  140.     private ArrayList<ArrayList<float[]>> initCluster() {  
  141.         ArrayList<ArrayList<float[]>> cluster = new ArrayList<ArrayList<float[]>>();  
  142.         for (int i = 0; i < k; i++) {  
  143.             cluster.add(new ArrayList<float[]>());  
  144.         }  
  145.   
  146.         return cluster;  
  147.     }  
  148.   
  149.     /** 
  150.      * 计算两个点之间的距离 
  151.      *  
  152.      * @param element 
  153.      *            点1 
  154.      * @param center 
  155.      *            点2 
  156.      * @return 距离 
  157.      */  
  158.     private float distance(float[] element, float[] center) {  
  159.         float distance = 0.0f;  
  160.         float x = element[0] - center[0];  
  161.         float y = element[1] - center[1];  
  162.         float z = x * x + y * y;  
  163.         distance = (float) Math.sqrt(z);  
  164.   
  165.         return distance;  
  166.     }  
  167.   
  168.     /** 
  169.      * 获取距离集合中最小距离的位置 
  170.      *  
  171.      * @param distance 
  172.      *            距离数组 
  173.      * @return 最小距离在距离数组中的位置 
  174.      */  
  175.     private int minDistance(float[] distance) {  
  176.         float minDistance = distance[0];  
  177.         int minLocation = 0;  
  178.         for (int i = 1; i < distance.length; i++) {  
  179.             if (distance[i] < minDistance) {  
  180.                 minDistance = distance[i];  
  181.                 minLocation = i;  
  182.             } else if (distance[i] == minDistance) // 如果相等,随机返回一个位置  
  183.             {  
  184.                 if (random.nextInt(10) < 5) {  
  185.                     minLocation = i;  
  186.                 }  
  187.             }  
  188.         }  
  189.   
  190.         return minLocation;  
  191.     }  
  192.   
  193.     /** 
  194.      * 核心,将当前元素放到最小距离中心相关的簇中 
  195.      */  
  196.     private void clusterSet() {  
  197.         float[] distance = new float[k];  
  198.         for (int i = 0; i < dataSetLength; i++) {  
  199.             for (int j = 0; j < k; j++) {  
  200.                 distance[j] = distance(dataSet.get(i), center.get(j));  
  201.                 // System.out.println("test2:"+"dataSet["+i+"],center["+j+"],distance="+distance[j]);  
  202.   
  203.             }  
  204.             int minLocation = minDistance(distance);  
  205.             // System.out.println("test3:"+"dataSet["+i+"],minLocation="+minLocation);  
  206.             // System.out.println();  
  207.   
  208.             cluster.get(minLocation).add(dataSet.get(i));// 核心,将当前元素放到最小距离中心相关的簇中  
  209.   
  210.         }  
  211.     }  
  212.   
  213.     /** 
  214.      * 求两点误差平方的方法 
  215.      *  
  216.      * @param element 
  217.      *            点1 
  218.      * @param center 
  219.      *            点2 
  220.      * @return 误差平方 
  221.      */  
  222.     private float errorSquare(float[] element, float[] center) {  
  223.         float x = element[0] - center[0];  
  224.         float y = element[1] - center[1];  
  225.   
  226.         float errSquare = x * x + y * y;  
  227.   
  228.         return errSquare;  
  229.     }  
  230.   
  231.     /** 
  232.      * 计算误差平方和准则函数方法 
  233.      */  
  234.     private void countRule() {  
  235.         float jcF = 0;  
  236.         for (int i = 0; i < cluster.size(); i++) {  
  237.             for (int j = 0; j < cluster.get(i).size(); j++) {  
  238.                 jcF += errorSquare(cluster.get(i).get(j), center.get(i));  
  239.   
  240.             }  
  241.         }  
  242.         jc.add(jcF);  
  243.     }  
  244.   
  245.     /** 
  246.      * 设置新的簇中心方法 
  247.      */  
  248.     private void setNewCenter() {  
  249.         for (int i = 0; i < k; i++) {  
  250.             int n = cluster.get(i).size();  
  251.             if (n != 0) {  
  252.                 float[] newCenter = { 00 };  
  253.                 for (int j = 0; j < n; j++) {  
  254.                     newCenter[0] += cluster.get(i).get(j)[0];  
  255.                     newCenter[1] += cluster.get(i).get(j)[1];  
  256.                 }  
  257.                 // 设置一个平均值  
  258.                 newCenter[0] = newCenter[0] / n;  
  259.                 newCenter[1] = newCenter[1] / n;  
  260.                 center.set(i, newCenter);  
  261.             }  
  262.         }  
  263.     }  
  264.   
  265.     /** 
  266.      * 打印数据,测试用 
  267.      *  
  268.      * @param dataArray 
  269.      *            数据集 
  270.      * @param dataArrayName 
  271.      *            数据集名称 
  272.      */  
  273.     public void printDataArray(ArrayList<float[]> dataArray,  
  274.             String dataArrayName) {  
  275.         for (int i = 0; i < dataArray.size(); i++) {  
  276.             System.out.println("print:" + dataArrayName + "[" + i + "]={"  
  277.                     + dataArray.get(i)[0] + "," + dataArray.get(i)[1] + "}");  
  278.         }  
  279.         System.out.println("===================================");  
  280.     }  
  281.   
  282.     /** 
  283.      * Kmeans算法核心过程方法 
  284.      */  
  285.     private void kmeans() {  
  286.         init();  
  287.         // printDataArray(dataSet,"initDataSet");  
  288.         // printDataArray(center,"initCenter");  
  289.   
  290.         // 循环分组,直到误差不变为止  
  291.         while (true) {  
  292.             clusterSet();  
  293.             // for(int i=0;i<cluster.size();i++)  
  294.             // {  
  295.             // printDataArray(cluster.get(i),"cluster["+i+"]");  
  296.             // }  
  297.   
  298.             countRule();  
  299.   
  300.             // System.out.println("count:"+"jc["+m+"]="+jc.get(m));  
  301.   
  302.             // System.out.println();  
  303.             // 误差不变了,分组完成  
  304.             if (m != 0) {  
  305.                 if (jc.get(m) - jc.get(m - 1) == 0) {  
  306.                     break;  
  307.                 }  
  308.             }  
  309.   
  310.             setNewCenter();  
  311.             // printDataArray(center,"newCenter");  
  312.             m++;  
  313.             cluster.clear();  
  314.             cluster = initCluster();  
  315.         }  
  316.   
  317.         // System.out.println("note:the times of repeat:m="+m);//输出迭代次数  
  318.     }  
  319.   
  320.     /** 
  321.      * 执行算法 
  322.      */  
  323.     public void execute() {  
  324.         long startTime = System.currentTimeMillis();  
  325.         System.out.println("kmeans begins");  
  326.         kmeans();  
  327.         long endTime = System.currentTimeMillis();  
  328.         System.out.println("kmeans running time=" + (endTime - startTime)  
  329.                 + "ms");  
  330.         System.out.println("kmeans ends");  
  331.         System.out.println();  
  332.     }  
  333. }  


4、说明:具体代码是从网上找的,根据自己的理解加了注释和进行部分修改,若注释有误还望指正

5、测试

[java]  view plain copy
  1. package org.test;  
  2.   
  3. import java.util.ArrayList;  
  4.   
  5. import org.algorithm.Kmeans;  
  6.   
  7. public class KmeansTest {  
  8.     public  static void main(String[] args)  
  9.     {  
  10.         //初始化一个Kmean对象,将k置为10  
  11.         Kmeans k=new Kmeans(10);  
  12.         ArrayList<float[]> dataSet=new ArrayList<float[]>();  
  13.           
  14.         dataSet.add(new float[]{1,2});  
  15.         dataSet.add(new float[]{3,3});  
  16.         dataSet.add(new float[]{3,4});  
  17.         dataSet.add(new float[]{5,6});  
  18.         dataSet.add(new float[]{8,9});  
  19.         dataSet.add(new float[]{4,5});  
  20.         dataSet.add(new float[]{6,4});  
  21.         dataSet.add(new float[]{3,9});  
  22.         dataSet.add(new float[]{5,9});  
  23.         dataSet.add(new float[]{4,2});  
  24.         dataSet.add(new float[]{1,9});  
  25.         dataSet.add(new float[]{7,8});  
  26.         //设置原始数据集  
  27.         k.setDataSet(dataSet);  
  28.         //执行算法  
  29.         k.execute();  
  30.         //得到聚类结果  
  31.         ArrayList<ArrayList<float[]>> cluster=k.getCluster();  
  32.         //查看结果  
  33.         for(int i=0;i<cluster.size();i++)  
  34.         {  
  35.             k.printDataArray(cluster.get(i), "cluster["+i+"]");  
  36.         }  
  37.           
  38.     }  
  39. }  


6、总结:测试代码已经通过。并对聚类的结果进行了查看,结果基本上符合要求。至于有没有更精确的算法有待发现。具体的实践还有待挖掘
  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值