Weka算法Classifier-tree-RandomForest源码分析(二)代码实现

                           Weka算法Classifier-tree-RandomForest源码分析(二)代码实现


RandomForest的实现异常的简单,简单的超出博主的预期,Weka在实现方式上组合了Bagging和RandomTree。


一、RandomForest的训练

构建RandomForest的代码如下:

[java]  view plain  copy
  1. public void buildClassifier(Instances data) throws Exception {  
  2.   
  3.   // can classifier handle the data?  
  4.   getCapabilities().testWithFail(data);  
  5.   
  6.   // remove instances with missing class  
  7.   data = new Instances(data);  
  8.   data.deleteWithMissingClass();  
  9.   
  10.   m_bagger = new Bagging();  
  11.   RandomTree rTree = new RandomTree();  
  12.   
  13.   // set up the random tree options  
  14.   m_KValue = m_numFeatures;  
  15.   if (m_KValue < 1)  
  16.     m_KValue = (int) Utils.log2(data.numAttributes()) + 1;  
  17.   rTree.setKValue(m_KValue);  
  18.   rTree.setMaxDepth(getMaxDepth());  
  19.   
  20.   // set up the bagger and build the forest  
  21.   m_bagger.setClassifier(rTree);  
  22.   m_bagger.setSeed(m_randomSeed);  
  23.   m_bagger.setNumIterations(m_numTrees);  
  24.   m_bagger.setCalcOutOfBag(true);  
  25.   m_bagger.buildClassifier(data);  
  26. }  
通过这段代码很直观的可以看出首先把无效数据去掉,然后建立了一个Bag,设置随机森林中每棵树所用到的属性的值,设置最大深度,接着把这棵RandomTree当做基分类器传递给Bagging,最后调用bagging的训练方法进行训练。


二、RandomForest分类

看完训练过程看具体的分类过程,也就是classifyInstance函数,值得注意的是,RandomForest继承自Classifier,却没有队classifyInstance方法进行重载,使用的是基类Classifier的classifyInstance函数,但却重载了distributionForInstance,而distributionForInstance却是Classifier的classifyInstance函数所用到的一个函数,返回一个instance在所有类上的概率。代码如下:

[java]  view plain  copy
  1. public double[] distributionForInstance(Instance instance) throws Exception {  
  2.   
  3.   return m_bagger.distributionForInstance(instance);  
  4. }  
可以看到,算出给定instance在各class上的分布是委托给bagger去做的(真懒),所以这里也不做详细分析,详细分析留到分析bagger的时候再说。

接下来看基类Classifier是如何使用distribution来给出分类结果的。

[java]  view plain  copy
  1. public double classifyInstance(Instance instance) throws Exception {  
  2.   
  3.   double[] dist = distributionForInstance(instance);  
  4.   if (dist == null) {  
  5.     throw new Exception("Null distribution predicted");  
  6.   }  
  7.   switch (instance.classAttribute().type()) {  
  8.   case Attribute.NOMINAL:  
  9.     double max = 0;  
  10.     int maxIndex = 0;  
  11.   
  12.     for (int i = 0; i < dist.length; i++) {  
  13.       if (dist[i] > max) {  
  14.         maxIndex = i;  
  15.         max = dist[i];  
  16.       }  
  17.     }  
  18.     if (max > 0) {  
  19.       return maxIndex;  
  20.     } else {  
  21.       return Instance.missingValue();  
  22.     }  
  23.   case Attribute.NUMERIC:  
  24.   case Attribute.DATE:  
  25.     return dist[0];  
  26.   default:  
  27.     return Instance.missingValue();  
  28.   }  
  29. }  
可以很直观的看到,如果要是一个分类,则给出概率最大值,如果是一个回归(即classIndex对应的属性是数值),则返回dist[0],这里是使用了一个约定,第一个元素代表回归值。


三、总结

对于RandomForest的代码分析差不多就结束了,基本没什么实质内容,因为算法的主要工作都交由Bagging和RandomForest去做了,值得注意的是,当没有指定抽样属性的数量时,Weka使用的log2(K)作为经验值。


评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值