BIRCH算法---使用聚类特征树的多阶段算法

更多数据挖掘代码:https://github.com/linyiqun/DataMiningAlgorithm

介绍

BIRCH算法本身上属于一种聚类算法,不过他克服了一些K-Means算法的缺点,比如说这个k的确定,因为这个算法事先本身就没有设定有多少个聚类。他是通过CF-Tree,(ClusterFeature-Tree)聚类特征树实现的。BIRCH的一个重要考虑是最小化I/O,通过扫描数据库,建立一棵存放于内存的初始CF-树,可以看做多数据的多层压缩。

算法原理

CF聚类特征

说到算法原理,首先就要先知道,什么是聚类特征,何为聚类特征,定义如下:

CF = <n, LS, SS>

聚类特征为一个3维向量,n为数据点总数,LS为n个点的线性和,SS为n个点的平方和。因此又可以得到

x0 = LS/n为簇的中心,以此计算簇与簇之间的距离。

簇内对象的平均距离簇直径,这个可以用阈值T限制,保证簇的一个整体的紧凑程度。簇和簇之间可以进行叠加,其实就是向量的叠加。

CF-Tree的构造过程

在介绍CF-Tree树,要先介绍3个变量,内部节点平衡因子B,叶节点平衡因子L,簇直径阈值T。B是用来限制非叶子节点的子节点数,L是用来限制叶子节点的子簇个数,T是用来限制簇的紧密程度的,比较的是D--簇内平均对象的距离。下面是主要的构造过程:

1、首先读入第一条数据,构造一个叶子节点和一个子簇,子簇包含在叶子节点中。

2、当读入后面的第2条,第3条,封装为一个簇,加入到一个叶子节点时,如果此时的待加入的簇C的簇直径已经大于T,则需要新建簇作为C的兄弟节点,如果作为兄弟节点,如果此时的叶子节点的孩子节点超过阈值L,则需对叶子节点进行分裂。分裂的规则是选出簇间距离最大的2个孩子,分别作为2个叶子,然后其他的孩子按照就近分配。非叶子节点的分裂规则同上。具体可以对照后面我写的代码。

3、最终的构造模样大致如此:


算法的优点:

1、算法只需扫描一遍就可以得到一个好的聚类效果,而且不需事先设定聚类个数。

2、聚类通过聚类特征树的形式,一定程度上保存了对数据的压缩。

算法的缺点:

1、该算法比较适合球形的簇,如果簇不是球形的,则聚簇的效果将不会很好。

算法的代码实现:

下面提供部分核心代码(如果想获取所有的代码,请点击我的数据挖掘代码):

数据的输入:

[java]  view plain copy print ?
  1. 5.1     3.5     1.4     0.2  
  2. 4.9     3.0     1.4     0.2  
  3. 4.7     3.2     1.3     0.8  
  4. 4.6     3.1     1.5     0.8  
  5. 5.0     3.6     1.8     0.6  
  6. 4.7     3.2     1.4     0.8  

ClusteringFeature.java:

[java]  view plain copy print ?
  1. package DataMining_BIRCH;  
  2.   
  3. import java.util.ArrayList;  
  4.   
  5. /** 
  6.  * 聚类特征基本属性 
  7.  *  
  8.  * @author lyq 
  9.  *  
  10.  */  
  11. public abstract class ClusteringFeature {  
  12.     // 子类中节点的总数目  
  13.     protected int N;  
  14.     // 子类中N个节点的线性和  
  15.     protected double[] LS;  
  16.     // 子类中N个节点的平方和  
  17.     protected double[] SS;  
  18.     //节点深度,用于CF树的输出  
  19.     protected int level;  
  20.   
  21.     public int getN() {  
  22.         return N;  
  23.     }  
  24.   
  25.     public void setN(int n) {  
  26.         N = n;  
  27.     }  
  28.   
  29.     public double[] getLS() {  
  30.         return LS;  
  31.     }  
  32.   
  33.     public void setLS(double[] lS) {  
  34.         LS = lS;  
  35.     }  
  36.   
  37.     public double[] getSS() {  
  38.         return SS;  
  39.     }  
  40.   
  41.     public void setSS(double[] sS) {  
  42.         SS = sS;  
  43.     }  
  44.   
  45.     protected void setN(ArrayList<double[]> dataRecords) {  
  46.         this.N = dataRecords.size();  
  47.     }  
  48.       
  49.     public int getLevel() {  
  50.         return level;  
  51.     }  
  52.   
  53.     public void setLevel(int level) {  
  54.         this.level = level;  
  55.     }  
  56.   
  57.     /** 
  58.      * 根据节点数据计算线性和 
  59.      *  
  60.      * @param dataRecords 
  61.      *            节点数据记录 
  62.      */  
  63.     protected void setLS(ArrayList<double[]> dataRecords) {  
  64.         int num = dataRecords.get(0).length;  
  65.         double[] record;  
  66.         LS = new double[num];  
  67.         for (int j = 0; j < num; j++) {  
  68.             LS[j] = 0;  
  69.         }  
  70.   
  71.         for (int i = 0; i < dataRecords.size(); i++) {  
  72.             record = dataRecords.get(i);  
  73.             for (int j = 0; j < record.length; j++) {  
  74.                 LS[j] += record[j];  
  75.             }  
  76.         }  
  77.     }  
  78.   
  79.     /** 
  80.      * 根据节点数据计算平方 
  81.      *  
  82.      * @param dataRecords 
  83.      *            节点数据 
  84.      */  
  85.     protected void setSS(ArrayList<double[]> dataRecords) {  
  86.         int num = dataRecords.get(0).length;  
  87.         double[] record;  
  88.         SS = new double[num];  
  89.         for (int j = 0; j < num; j++) {  
  90.             SS[j] = 0;  
  91.         }  
  92.   
  93.         for (int i = 0; i < dataRecords.size(); i++) {  
  94.             record = dataRecords.get(i);  
  95.             for (int j = 0; j < record.length; j++) {  
  96.                 SS[j] += record[j] * record[j];  
  97.             }  
  98.         }  
  99.     }  
  100.   
  101.     /** 
  102.      * CF向量特征的叠加,无须考虑划分 
  103.      *  
  104.      * @param node 
  105.      */  
  106.     protected void directAddCluster(ClusteringFeature node) {  
  107.         int N = node.getN();  
  108.         double[] otherLS = node.getLS();  
  109.         double[] otherSS = node.getSS();  
  110.           
  111.         if(LS == null){  
  112.             this.N = 0;  
  113.             LS = new double[otherLS.length];  
  114.             SS = new double[otherLS.length];  
  115.               
  116.             for(int i=0; i<LS.length; i++){  
  117.                 LS[i] = 0;  
  118.                 SS[i] = 0;  
  119.             }  
  120.         }  
  121.   
  122.         // 3个数量上进行叠加  
  123.         for (int i = 0; i < LS.length; i++) {  
  124.             LS[i] += otherLS[i];  
  125.             SS[i] += otherSS[i];  
  126.         }  
  127.         this.N += N;  
  128.     }  
  129.   
  130.     /** 
  131.      * 计算簇与簇之间的距离即簇中心之间的距离 
  132.      *  
  133.      * @return 
  134.      */  
  135.     protected double computerClusterDistance(ClusteringFeature cluster) {  
  136.         double distance = 0;  
  137.         double[] otherLS = cluster.LS;  
  138.         int num = N;  
  139.           
  140.         int otherNum = cluster.N;  
  141.   
  142.         for (int i = 0; i < LS.length; i++) {  
  143.             distance += (LS[i] / num - otherLS[i] / otherNum)  
  144.                     * (LS[i] / num - otherLS[i] / otherNum);  
  145.         }  
  146.         distance = Math.sqrt(distance);  
  147.   
  148.         return distance;  
  149.     }  
  150.   
  151.     /** 
  152.      * 计算簇内对象的平均距离 
  153.      *  
  154.      * @param records 
  155.      *            簇内的数据记录 
  156.      * @return 
  157.      */  
  158.     protected double computerInClusterDistance(ArrayList<double[]> records) {  
  159.         double sumDistance = 0;  
  160.         double[] data1;  
  161.         double[] data2;  
  162.         // 数据总数  
  163.         int totalNum = records.size();  
  164.   
  165.         for (int i = 0; i < totalNum - 1; i++) {  
  166.             data1 = records.get(i);  
  167.             for (int j = i + 1; j < totalNum; j++) {  
  168.                 data2 = records.get(j);  
  169.                 sumDistance += computeOuDistance(data1, data2);  
  170.             }  
  171.         }  
  172.   
  173.         // 返回的值除以总对数,总对数应减半,会重复算一次  
  174.         return Math.sqrt(sumDistance / (totalNum * (totalNum - 1) / 2));  
  175.     }  
  176.   
  177.     /** 
  178.      * 对给定的2个向量,计算欧式距离 
  179.      *  
  180.      * @param record1 
  181.      *            向量点1 
  182.      * @param record2 
  183.      *            向量点2 
  184.      */  
  185.     private double computeOuDistance(double[] record1, double[] record2) {  
  186.         double distance = 0;  
  187.   
  188.         for (int i = 0; i < record1.length; i++) {  
  189.             distance += (record1[i] - record2[i]) * (record1[i] - record2[i]);  
  190.         }  
  191.   
  192.         return distance;  
  193.     }  
  194.   
  195.     /** 
  196.      * 聚类添加节点包括,超出阈值进行分裂的操作 
  197.      *  
  198.      * @param clusteringFeature 
  199.      *            待添加聚簇 
  200.      */  
  201.     public abstract void addingCluster(ClusteringFeature clusteringFeature);  
  202. }  
BIRCHTool.java:

[java]  view plain copy print ?
  1. package DataMining_BIRCH;  
  2.   
  3. import java.io.BufferedReader;  
  4. import java.io.File;  
  5. import java.io.FileReader;  
  6. import java.io.IOException;  
  7. import java.text.MessageFormat;  
  8. import java.util.ArrayList;  
  9. import java.util.LinkedList;  
  10.   
  11. /** 
  12.  * BIRCH聚类算法工具类 
  13.  *  
  14.  * @author lyq 
  15.  *  
  16.  */  
  17. public class BIRCHTool {  
  18.     // 节点类型名称  
  19.     public static final String NON_LEAFNODE = "【NonLeafNode】";  
  20.     public static final String LEAFNODE = "【LeafNode】";  
  21.     public static final String CLUSTER = "【Cluster】";  
  22.   
  23.     // 测试数据文件地址  
  24.     private String filePath;  
  25.     // 内部节点平衡因子B  
  26.     public static int B;  
  27.     // 叶子节点平衡因子L  
  28.     public static int L;  
  29.     // 簇直径阈值T  
  30.     public static double T;  
  31.     // 总的测试数据记录  
  32.     private ArrayList<String[]> totalDataRecords;  
  33.   
  34.     public BIRCHTool(String filePath, int B, int L, double T) {  
  35.         this.filePath = filePath;  
  36.         this.B = B;  
  37.         this.L = L;  
  38.         this.T = T;  
  39.         readDataFile();  
  40.     }  
  41.   
  42.     /** 
  43.      * 从文件中读取数据 
  44.      */  
  45.     private void readDataFile() {  
  46.         File file = new File(filePath);  
  47.         ArrayList<String[]> dataArray = new ArrayList<String[]>();  
  48.   
  49.         try {  
  50.             BufferedReader in = new BufferedReader(new FileReader(file));  
  51.             String str;  
  52.             String[] tempArray;  
  53.             while ((str = in.readLine()) != null) {  
  54.                 tempArray = str.split("     ");  
  55.                 dataArray.add(tempArray);  
  56.             }  
  57.             in.close();  
  58.         } catch (IOException e) {  
  59.             e.getStackTrace();  
  60.         }  
  61.   
  62.         totalDataRecords = new ArrayList<>();  
  63.         for (String[] array : dataArray) {  
  64.             totalDataRecords.add(array);  
  65.         }  
  66.     }  
  67.   
  68.     /** 
  69.      * 构建CF聚类特征树 
  70.      *  
  71.      * @return 
  72.      */  
  73.     private ClusteringFeature buildCFTree() {  
  74.         NonLeafNode rootNode = null;  
  75.         LeafNode leafNode = null;  
  76.         Cluster cluster = null;  
  77.   
  78.         for (String[] record : totalDataRecords) {  
  79.             cluster = new Cluster(record);  
  80.   
  81.             if (rootNode == null) {  
  82.                 // CF树只有1个节点的时候的情况  
  83.                 if (leafNode == null) {  
  84.                     leafNode = new LeafNode();  
  85.                 }  
  86.                 leafNode.addingCluster(cluster);  
  87.                 if (leafNode.getParentNode() != null) {  
  88.                     rootNode = leafNode.getParentNode();  
  89.                 }  
  90.             } else {  
  91.                 if (rootNode.getParentNode() != null) {  
  92.                     rootNode = rootNode.getParentNode();  
  93.                 }  
  94.   
  95.                 // 从根节点开始,从上往下寻找到最近的添加目标叶子节点  
  96.                 LeafNode temp = rootNode.findedClosestNode(cluster);  
  97.                 temp.addingCluster(cluster);  
  98.             }  
  99.         }  
  100.   
  101.         // 从下往上找出最上面的节点  
  102.         LeafNode node = cluster.getParentNode();  
  103.         NonLeafNode upNode = node.getParentNode();  
  104.         if (upNode == null) {  
  105.             return node;  
  106.         } else {  
  107.             while (upNode.getParentNode() != null) {  
  108.                 upNode = upNode.getParentNode();  
  109.             }  
  110.   
  111.             return upNode;  
  112.         }  
  113.     }  
  114.   
  115.     /** 
  116.      * 开始构建CF聚类特征树 
  117.      */  
  118.     public void startBuilding() {  
  119.         // 树深度  
  120.         int level = 1;  
  121.         ClusteringFeature rootNode = buildCFTree();  
  122.   
  123.         setTreeLevel(rootNode, level);  
  124.         showCFTree(rootNode);  
  125.     }  
  126.   
  127.     /** 
  128.      * 设置节点深度 
  129.      *  
  130.      * @param clusteringFeature 
  131.      *            当前节点 
  132.      * @param level 
  133.      *            当前深度值 
  134.      */  
  135.     private void setTreeLevel(ClusteringFeature clusteringFeature, int level) {  
  136.         LeafNode leafNode = null;  
  137.         NonLeafNode nonLeafNode = null;  
  138.   
  139.         if (clusteringFeature instanceof LeafNode) {  
  140.             leafNode = (LeafNode) clusteringFeature;  
  141.         } else if (clusteringFeature instanceof NonLeafNode) {  
  142.             nonLeafNode = (NonLeafNode) clusteringFeature;  
  143.         }  
  144.   
  145.         if (nonLeafNode != null) {  
  146.             nonLeafNode.setLevel(level);  
  147.             level++;  
  148.             // 设置子节点  
  149.             if (nonLeafNode.getNonLeafChilds() != null) {  
  150.                 for (NonLeafNode n1 : nonLeafNode.getNonLeafChilds()) {  
  151.                     setTreeLevel(n1, level);  
  152.                 }  
  153.             } else {  
  154.                 for (LeafNode n2 : nonLeafNode.getLeafChilds()) {  
  155.                     setTreeLevel(n2, level);  
  156.                 }  
  157.             }  
  158.         } else {  
  159.             leafNode.setLevel(level);  
  160.             level++;  
  161.             // 设置子聚簇  
  162.             for (Cluster c : leafNode.getClusterChilds()) {  
  163.                 c.setLevel(level);  
  164.             }  
  165.         }  
  166.     }  
  167.   
  168.     /** 
  169.      * 显示CF聚类特征树 
  170.      *  
  171.      * @param rootNode 
  172.      *            CF树根节点 
  173.      */  
  174.     private void showCFTree(ClusteringFeature rootNode) {  
  175.         // 空格数,用于输出  
  176.         int blankNum = 5;  
  177.         // 当前树深度  
  178.         int currentLevel = 1;  
  179.         LinkedList<ClusteringFeature> nodeQueue = new LinkedList<>();  
  180.         ClusteringFeature cf;  
  181.         LeafNode leafNode;  
  182.         NonLeafNode nonLeafNode;  
  183.         ArrayList<Cluster> clusterList = new ArrayList<>();  
  184.         String typeName;  
  185.   
  186.         nodeQueue.add(rootNode);  
  187.         while (nodeQueue.size() > 0) {  
  188.             cf = nodeQueue.poll();  
  189.   
  190.             if (cf instanceof LeafNode) {  
  191.                 leafNode = (LeafNode) cf;  
  192.                 typeName = LEAFNODE;  
  193.   
  194.                 if (leafNode.getClusterChilds() != null) {  
  195.                     for (Cluster c : leafNode.getClusterChilds()) {  
  196.                         nodeQueue.add(c);  
  197.                     }  
  198.                 }  
  199.             } else if (cf instanceof NonLeafNode) {  
  200.                 nonLeafNode = (NonLeafNode) cf;  
  201.                 typeName = NON_LEAFNODE;  
  202.   
  203.                 if (nonLeafNode.getNonLeafChilds() != null) {  
  204.                     for (NonLeafNode n1 : nonLeafNode.getNonLeafChilds()) {  
  205.                         nodeQueue.add(n1);  
  206.                     }  
  207.                 } else {  
  208.                     for (LeafNode n2 : nonLeafNode.getLeafChilds()) {  
  209.                         nodeQueue.add(n2);  
  210.                     }  
  211.                 }  
  212.             } else {  
  213.                 clusterList.add((Cluster)cf);  
  214.                 typeName = CLUSTER;  
  215.             }  
  216.   
  217.             if (currentLevel != cf.getLevel()) {  
  218.                 currentLevel = cf.getLevel();  
  219.                 System.out.println();  
  220.                 System.out.println("|");  
  221.                 System.out.println("|");  
  222.             }else if(currentLevel == cf.getLevel() && currentLevel != 1){  
  223.                 for (int i = 0; i < blankNum; i++) {  
  224.                     System.out.print("-");  
  225.                 }  
  226.             }  
  227.               
  228.             System.out.print(typeName);  
  229.             System.out.print("N:" + cf.getN() + ", LS:");  
  230.             System.out.print("[");  
  231.             for (double d : cf.getLS()) {  
  232.                 System.out.print(MessageFormat.format("{0}, ",  d));  
  233.             }  
  234.             System.out.print("]");  
  235.         }  
  236.           
  237.         System.out.println();  
  238.         System.out.println("*******最终分好的聚簇****");  
  239.         //显示已经分好类的聚簇点  
  240.         for(int i=0; i<clusterList.size(); i++){  
  241.             System.out.println("Cluster" + (i+1) + ":");  
  242.             for(double[] point: clusterList.get(i).getData()){  
  243.                 System.out.print("[");  
  244.                 for (double d : point) {  
  245.                     System.out.print(MessageFormat.format("{0}, ",  d));  
  246.                 }  
  247.                 System.out.println("]");  
  248.             }  
  249.         }  
  250.     }  
  251.   
  252. }  
由于代码量比较大,剩下的LeafNode.java,NonLeafNode.java, 和Cluster聚簇类可以在 我的数据挖掘代码 中查看。

结果输出:

[java]  view plain copy print ?
  1. 【NonLeafNode】N:6, LS:[2919.68.83.4, ]  
  2. |  
  3. |  
  4. 【LeafNode】N:3, LS:[149.54.22.4, ]-----【LeafNode】N:3, LS:[1510.14.61, ]  
  5. |  
  6. |  
  7. 【Cluster】N:3, LS:[149.54.22.4, ]-----【Cluster】N:1, LS:[53.61.80.6, ]-----【Cluster】N:2, LS:[106.52.80.4, ]  
  8. *******最终分好的聚簇****  
  9. Cluster1:  
  10. [4.73.21.30.8, ]  
  11. [4.63.11.50.8, ]  
  12. [4.73.21.40.8, ]  
  13. Cluster2:  
  14. [53.61.80.6, ]  
  15. Cluster3:  
  16. [5.13.51.40.2, ]  
  17. [4.931.40.2, ]  

算法实现时的难点

1、算簇间距离的时候,代了一下公式,发现不对劲,向量的运算不应该是这样的,于是就把他归与簇心之间的距离计算。还有簇内对象的平均距离也没有代入公式,网上的各种版本的向量计算,不知道哪种是对的,又按最原始的方式计算,一对对计算距离,求平均值。

2、算法在节点分裂的时候,如果父节点不为空,需要把自己从父亲中的孩子列表中移除,然后再添加分裂后的2个节点,这里的把自己移除掉容易忘记。

3、节点CF聚类特征值的更新,需要在每次节点的变化时,其所涉及的父类,父类的父类都需要更新,为此用了责任链模式,一个一个往上传,分裂的规则时也用了此模式,需要关注一下。

4、代码将CF聚类特征量进行抽象提取,定义了共有的方法,不过在实现时还是由于节点类型的不同,在实际的过程中需要转化。

5、最后的难点在与测试的复杂,因为程序经过千辛万苦的编写终于完成,但是如何测试时一个大问题,因为要把分裂的情况都测准,需要准确的把握T.,B.L,的设计,尤其是T簇直径,所以在捏造测试的时候自己也是经过很多的手动计算。

我对BIRCH算法的理解

在实现的整个完成的过程中 ,我对BIRCH算法的最大的感触就是通过聚类特征,一个新节点从根节点开始,从上往先寻找,离哪个簇近,就被分到哪个簇中,自发的形成了一个比较好的聚簇,这个过程是算法的神奇所在。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值