Weka算法Classifier-tree-J48源码分析(三)ModelSelection

ModelSelection主要是用于选择合适的列对数据集进行分割,结合上一篇J48的主流程,发现用到的ModelSelection有 C45ModelSelection以及BinC45ModelSelection,先来分析C45ModelSelection。


一、C45ModelSelection

首先作为一个ModelSelection接口,实现的主要方法有两个,分别是selectModel(Instances)和selectionModel(Instances,Instances)。C45ModelSelection的后一个方法如下:

[java]   view plain copy
  1. public final ClassifierSplitModel selectModel(Instances train, Instances test) {  
  2.   
  3.   return selectModel(train);  
  4. }  
可以看到就是忽略了test测试集直接调用selectModel方法而已,因此主要分词selectModel方法。

先放出整段代码,然后对该段代码进行分析:

[java]   view plain copy
  1. public final ClassifierSplitModel selectModel(Instances data){  
  2.   
  3.     double minResult;  
  4.     double currentResult;  
  5.     C45Split [] currentModel;  
  6.     C45Split bestModel = null;  
  7.     NoSplit noSplitModel = null;  
  8.     double averageInfoGain = 0;  
  9.     int validModels = 0;  
  10.     boolean multiVal = true;  
  11.     Distribution checkDistribution;  
  12.     Attribute attribute;  
  13.     double sumOfWeights;  
  14.     int i;  
  15.       
  16.     try{  
  17.   
  18.       // Check if all Instances belong to one class or if not  
  19.       // enough Instances to split.  
  20.       checkDistribution = new Distribution(data);  
  21.       noSplitModel = new NoSplit(checkDistribution);  
  22.       if (Utils.sm(checkDistribution.total(),2*m_minNoObj) ||  
  23.       Utils.eq(checkDistribution.total(),  
  24.            checkDistribution.perClass(checkDistribution.maxClass())))  
  25.     return noSplitModel;  
  26.   
  27.       // Check if all attributes are nominal and have a   
  28.       // lot of values.  
  29.       if (m_allData != null) {  
  30.     Enumeration enu = data.enumerateAttributes();  
  31.     while (enu.hasMoreElements()) {  
  32.       attribute = (Attribute) enu.nextElement();  
  33.       if ((attribute.isNumeric()) ||  
  34.           (Utils.sm((double)attribute.numValues(),  
  35.             (0.3*(double)m_allData.numInstances())))){  
  36.         multiVal = false;  
  37.         break;  
  38.       }  
  39.     }  
  40.       }   
  41.   
  42.       currentModel = new C45Split[data.numAttributes()];  
  43.       sumOfWeights = data.sumOfWeights();  
  44.   
  45.       // For each attribute.  
  46.       for (i = 0; i < data.numAttributes(); i++){  
  47.       
  48.     // Apart from class attribute.  
  49.     if (i != (data).classIndex()){  
  50.         
  51.       // Get models for current attribute.  
  52.       currentModel[i] = new C45Split(i,m_minNoObj,sumOfWeights);  
  53.       currentModel[i].buildClassifier(data);  
  54.         
  55.       // Check if useful split for current attribute  
  56.       // exists and check for enumerated attributes with   
  57.       // a lot of values.  
  58.       if (currentModel[i].checkModel())  
  59.         if (m_allData != null) {  
  60.           if ((data.attribute(i).isNumeric()) ||  
  61.           (multiVal || Utils.sm((double)data.attribute(i).numValues(),  
  62.                     (0.3*(double)m_allData.numInstances())))){  
  63.         averageInfoGain = averageInfoGain+currentModel[i].infoGain();  
  64.         validModels++;  
  65.           }   
  66.         } else {  
  67.           averageInfoGain = averageInfoGain+currentModel[i].infoGain();  
  68.           validModels++;  
  69.         }  
  70.     }else  
  71.       currentModel[i] = null;  
  72.       }  
  73.         
  74.       // Check if any useful split was found.  
  75.       if (validModels == 0)  
  76.     return noSplitModel;  
  77.       averageInfoGain = averageInfoGain/(double)validModels;  
  78.   
  79.       // Find "best" attribute to split on.  
  80.       minResult = 0;  
  81.       for (i=0;i<data.numAttributes();i++){  
  82.     if ((i != (data).classIndex()) &&  
  83.         (currentModel[i].checkModel()))  
  84.         
  85.       // Use 1E-3 here to get a closer approximation to the original  
  86.       // implementation.  
  87.       if ((currentModel[i].infoGain() >= (averageInfoGain-1E-3)) &&  
  88.           Utils.gr(currentModel[i].gainRatio(),minResult)){   
  89.         bestModel = currentModel[i];  
  90.         minResult = currentModel[i].gainRatio();  
  91.       }   
  92.       }  
  93.   
  94.       // Check if useful split was found.  
  95.       if (Utils.eq(minResult,0))  
  96.     return noSplitModel;  
  97.         
  98.       // Add all Instances with unknown values for the corresponding  
  99.       // attribute to the distribution for the model, so that  
  100.       // the complete distribution is stored with the model.   
  101.       bestModel.distribution().  
  102.       addInstWithUnknown(data,bestModel.attIndex());  
  103.         
  104.       // Set the split point analogue to C45 if attribute numeric.  
  105.       if (m_allData != null)  
  106.     bestModel.setSplitPoint(m_allData);  
  107.       return bestModel;  
  108.     }catch(Exception e){  
  109.       e.printStackTrace();  
  110.     }  
  111.     return null;  
  112.   }  
第一部分,主要是对局部变量的一些定义。

[java]   view plain copy
  1. double minResult;//最小的信息增益率  
  2. double currentResult;//当前信息增益率  
  3. C45Split [] currentModel;//存放所有未分类属性产生的模型  
  4. C45Split bestModel = null;//目前为止的最好模型  
  5. NoSplit noSplitModel = null;//代表不用分的模型  
  6. double averageInfoGain = 0;//各模型(currentModel)的平均信息增益  
  7. int validModels = 0;//是否存在有效模型  
  8. boolean multiVal = true;//是否多值  
  9. Distribution checkDistribution;//训练数据集的分布  
  10. Attribute attribute;//属性列集合  
  11. double sumOfWeights;//训练数据集的weight的和  
  12. int i;//循环变量  

第二部分,递归出口。

[java]   view plain copy
  1. checkDistribution = new Distribution(data);  
  2.      noSplitModel = new NoSplit(checkDistribution);  
  3.      if (Utils.sm(checkDistribution.total(),2*m_minNoObj) ||  
  4.   Utils.eq(checkDistribution.total(),  
  5.        checkDistribution.perClass(checkDistribution.maxClass())))  
  6. return noSplitModel;  
可以看到,如果当前数据集数量小于2*m_minNoObj(这个值默认是2),或者当前数据集已经全在同一个分类中,就返回noSplitModel代表不用分,这就是整个C45分类树节点停止分裂的条件。

第三部分,判断是否是多值:

[java]   view plain copy
  1.      if (m_allData != null) {  
  2. Enumeration enu = data.enumerateAttributes();  
  3. while (enu.hasMoreElements()) {  
  4.   attribute = (Attribute) enu.nextElement();  
  5.   if ((attribute.isNumeric()) ||  
  6.       (Utils.sm((double)attribute.numValues(),  
  7.         (0.3*(double)m_allData.numInstances())))){  
  8.     multiVal = false;  
  9.     break;  
  10.   }  
  11. }  
  12.      }   
如果属性中,任意一列是数值型,或者其取值的数量小于训练集数量*0.3,则不是多值,否则按多值处理。是否是多值影响到后面某些逻辑。

第四部分,对于每一列属性构造Spliter。

[java]   view plain copy
  1.    for (i = 0; i < data.numAttributes(); i++){  
  2.   
  3. // Apart from class attribute.  
  4. if (i != (data).classIndex()){  
  5.     
  6.   // Get models for current attribute.  
  7.   currentModel[i] = new C45Split(i,m_minNoObj,sumOfWeights);  
  8.   currentModel[i].buildClassifier(data);  
  9.     
  10.   // Check if useful split for current attribute  
  11.   // exists and check for enumerated attributes with   
  12.   // a lot of values.  
  13.   if (currentModel[i].checkModel())  
  14.     if (m_allData != null) {  
  15.       if ((data.attribute(i).isNumeric()) ||  
  16.       (multiVal || Utils.sm((double)data.attribute(i).numValues(),  
  17.                 (0.3*(double)m_allData.numInstances())))){  
  18.     averageInfoGain = averageInfoGain+currentModel[i].infoGain();  
  19.     validModels++;  
  20.       }   
  21.     } else {  
  22.       averageInfoGain = averageInfoGain+currentModel[i].infoGain();  
  23.       validModels++;  
  24.     }  
  25. }else  
  26.   currentModel[i] = null;  
  27.      }  

对于每一列属性,如果不是存放分类的值得话,则构造C45Split对象,在该对象上进行分类,然后算出信息增益,相加到averageInfoGain上。对于C45Split的构造,稍后再看。

第五部分,选出最优模型。

[java]   view plain copy
  1. if (validModels == 0)  
  2. return noSplitModel;  
  3.      averageInfoGain = averageInfoGain/(double)validModels;  
  4.   
  5.      // Find "best" attribute to split on.  
  6.      minResult = 0;  
  7.      for (i=0;i<data.numAttributes();i++){  
  8. if ((i != (data).classIndex()) &&  
  9.     (currentModel[i].checkModel()))  
  10.     
  11.   // Use 1E-3 here to get a closer approximation to the original  
  12.   // implementation.  
  13.   if ((currentModel[i].infoGain() >= (averageInfoGain-1E-3)) &&  
  14.       Utils.gr(currentModel[i].gainRatio(),minResult)){   
  15.     bestModel = currentModel[i];  
  16.     minResult = currentModel[i].gainRatio();  
  17.   }   

如果存在有效模型,则选出有效模型。注意这个选出最优模型的逻辑,并不是单纯的选出gainRatio最大的,而是在基础上必须还要大于平均信息增益,这也是和传统的c45算法不一样的一点。

从上述过程来看,Weka在实现C45的时候做了一个小的变动,并没有从“还没有使用的”属性列中找出最合理的列最为分割属性,而是在“所有的列”中找出最合理的列作为分割属性,虽然这二者在结果上肯定是等价的(之前是有过的属性不和能有很好的信息增益率),但效率上个人对Weka的做法持保留意见。


二、C45Spliter

在ModelSelection中真正根据属性对训练集进行分割、计算信息增益和信息增益率的是C45Spliter,首先也从其buildClassifier方法入手进行分析。

[java]   view plain copy
  1. public void buildClassifier(Instances trainInstances)   
  2.        throws Exception {  
  3.   
  4.     // Initialize the remaining instance variables.  
  5.     m_numSubsets = 0;  
  6.     m_splitPoint = Double.MAX_VALUE;  
  7.     m_infoGain = 0;  
  8.     m_gainRatio = 0;  
  9.   
  10.     // Different treatment for enumerated and numeric  
  11.     // attributes.  
  12.     if (trainInstances.attribute(m_attIndex).isNominal()) {  
  13.       m_complexityIndex = trainInstances.attribute(m_attIndex).numValues();  
  14.       m_index = m_complexityIndex;  
  15.       handleEnumeratedAttribute(trainInstances);  
  16.     }else{  
  17.       m_complexityIndex = 2;  
  18.       m_index = 0;  
  19.       trainInstances.sort(trainInstances.attribute(m_attIndex));  
  20.       handleNumericAttribute(trainInstances);  
  21.     }  
  22.   }      
可以看到,对于枚举型和数值型的属性是分开处理的,枚举型调用handlEnumeratedAttribute,数值型调用handleNumericAttribute,值得注意的是,在处理数值型之前,按照相应列进行排序,同时设置m_complexityIndex也就是期望分裂的节点数设定为2。

首先来看枚举类型是如何处理的。

[java]   view plain copy
  1. private void handleEnumeratedAttribute(Instances trainInstances)  
  2.        throws Exception {  
  3.       
  4.     Instance instance;  
  5.   
  6.     m_distribution = new Distribution(m_complexityIndex,  
  7.                   trainInstances.numClasses());  
  8.       
  9.     // Only Instances with known values are relevant.  
  10.     Enumeration enu = trainInstances.enumerateInstances();  
  11.     while (enu.hasMoreElements()) {  
  12.       instance = (Instance) enu.nextElement();  
  13.       if (!instance.isMissing(m_attIndex))  
  14.     m_distribution.add((int)instance.value(m_attIndex),instance);  
  15.     }  
  16.       
  17.     // Check if minimum number of Instances in at least two  
  18.     // subsets.  
  19.     if (m_distribution.check(m_minNoObj)) {  
  20.       m_numSubsets = m_complexityIndex;  
  21.       m_infoGain = infoGainCrit.  
  22.     splitCritValue(m_distribution,m_sumOfWeights);  
  23.       m_gainRatio =   
  24.     gainRatioCrit.splitCritValue(m_distribution,m_sumOfWeights,  
  25.                      m_infoGain);  
  26.     }  
  27.   }  
大概流程是新建一个分布,遍历所有instance,如果该instance对应的分裂的属性不为空的话,则放到不同的bag里,之后检查一下这个分布是否满足要求,要求就是最多允许有一个bag里的数据数量小于m_minNoObj,如果通过检查,就设置subset的数量,计算信息增益和信息增益率,否则subset默认会是0,上层调用checkModel就会返回false代表这是一个无效模型。

接下来看数值型是如何处理的:

[cpp]   view plain copy
  1. private void handleNumericAttribute(Instances trainInstances)  
  2.       throws Exception {  
  3.    
  4.    int firstMiss;//最后一个有效instance的下标  
  5.    int next = 1;//下一个instance的index  
  6.    int last = 0;//当前instance的index  
  7.    int splitIndex = -1;//分裂点  
  8.    double currentInfoGain;//当前信息增益  
  9.    double defaultEnt;//分割之前的信息熵  
  10.    double minSplit;  
  11.    Instance instance;  
  12.    int i;  
[cpp]   view plain copy
  1. //首先新建一个分布,数值型默认处理为2维分布,也就可以理解为小于某个值放到一个Bag里,其余的放到另外一个Bag里  
[cpp]   view plain copy
  1. m_distribution = new Distribution(2,trainInstances.numClasses());  
  2. Enumeration enu = trainInstances.enumerateInstances();  
  3. i = 0;  
[cpp]   view plain copy
  1. <pre name="code" class="cpp">//注意instances传入的时候是排好序的,这个排序保证了missingValue放在最后面,所以读到了missingValue其之后肯定都是miss//ingValue,换言之,firstMiss在循环之后代表了最后一个有效的instance的下标。  
while (enu.hasMoreElements()) { instance = (Instance) enu.nextElement(); if (instance.isMissing(m_attIndex))break; m_distribution.add(1,instance); i++; } firstMiss = i;//循环结束后,m_distribution里放入了所有的有效instance,并全放入了bag1里。
[cpp]   view plain copy
  1. //minSplit是最后分类好每个Bag里最小的数据的量,也就是0.1*每个类的均值。  
  2.     minSplit =  0.1*(m_distribution.total())/  
  3.       ((double)trainInstances.numClasses());  
  4.     if (Utils.smOrEq(minSplit,m_minNoObj))   
  5.       minSplit = m_minNoObj;  
  6.     else  
  7.       if (Utils.gr(minSplit,25))   
  8.     minSplit = 25;  
  9.       
  10. //如果有效数据总量不到2*minSplit,换言之无论怎么分均不能保证2个bag里的数量大于minSplit,就直接返回。  
  11.     if (Utils.sm((double)firstMiss,2*minSplit))  
  12.       return;  
  13.       
  14. //defaultEnt代表旧的信息熵,也就是对该属性进行分类之前,Indexclass对应的信息熵。  
  15.     defaultEnt = infoGainCrit.oldEnt(m_distribution);  
  16.     while (next < firstMiss) {  
  17.         
  18.       if (trainInstances.instance(next-1).value(m_attIndex)+1e-5 <   
  19.       trainInstances.instance(next).value(m_attIndex)) {   
  20.     <pre name="code" class="cpp">//Instances里的记录是升序排列的,加上这个条件默认把值相差很小的Instance就当做同一个instance处理了  
[cpp]   view plain copy
  1. //last代表当前,next代表下一个,默认next=1,last=0,所以shiftRange可以理解成把当前记录从bag1移动到bag0中  
[cpp]   view plain copy
  1. <span style="font-family: Arial, Helvetica, sans-serif;">//注意一开始初始化时候所有的都是在bag1里面的。   </span>  
m_distribution.shiftRange(1,0,trainInstances,last,next);if (Utils.grOrEq(m_distribution.perBag(0),minSplit) && //如果两个bag都满足最小数据集的数量minSplit Utils.grOrEq(m_distribution.perBag(1),minSplit)) { currentInfoGain = infoGainCrit. splitCritValue(m_distribution,m_sumOfWeights, //算一下信息增益 defaultEnt);
[cpp]   view plain copy
  1.   if (Utils.gr(currentInfoGain,m_infoGain)) {  
  2.     m_infoGain = currentInfoGain;//如果信息增益比当前最大的要大,则替换当前最大的值,并记录splitIndex  
  3.     splitIndex = next-1;  
  4.   }  
  5.   m_index++;  
  6. }  
  7. last = next;  
  8.      }  
  9.      next++;  
  10.    }  
  11.      
  12.    if (m_index == 0)  
  13.      return//执行到这里说明没找到一个合适的分裂点,直接返回。  
  14.      
  15.    // 计算最佳信息增益  
  16.    m_infoGain = m_infoGain-(Utils.log2(m_index)/m_sumOfWeights);  
  17.    if (Utils.smOrEq(m_infoGain,0))  
  18.      return//如果信息增益是0也说明没找到合适的分裂点,直接返回。  
  19.      
  20.    //剩下的就是根据分裂点进行属性的划分。  
  21.    m_numSubsets = 2;  
  22.    m_splitPoint =   
  23.      (trainInstances.instance(splitIndex+1).value(m_attIndex)+  
  24.       trainInstances.instance(splitIndex).value(m_attIndex))/2;  
  25.   
  26.    // In case we have a numerical precision problem we need to choose the  
  27.    // smaller value  
  28.    if (m_splitPoint == trainInstances.instance(splitIndex + 1).value(m_attIndex)) {  
  29.      m_splitPoint = trainInstances.instance(splitIndex).value(m_attIndex);  
  30.    }  
  31.   
  32.    // Restore distributioN for best split.  
  33.    m_distribution = new Distribution(2,trainInstances.numClasses());  
  34.    m_distribution.addRange(0,trainInstances,0,splitIndex+1);  
  35.    m_distribution.addRange(1,trainInstances,splitIndex+1,firstMiss);  
  36.   
  37.    // Compute modified gain ratio for best split.  
  38.    m_gainRatio = gainRatioCrit.  
  39.      splitCritValue(m_distribution,m_sumOfWeights,  
  40.          m_infoGain);  
  41.  }  
这个函数有点复杂,具体逻辑也写到代码注释里了。


三、BinC45ModelSelection

该函数只负责生成二元分类树的模型,selectModel方法和C45ModelSelection几乎一样,不在多说,不同点在于其使用BinC45Spliter而不是C45Spliter。


四、BinC45Spliter

 handleNumericAttribute对于数值类型的属性处理和C45Spliter完全一样。下面只分析一下handleEnumeratedAttribute。

[java]   view plain copy
  1. private void handleEnumeratedAttribute(Instances trainInstances)  
  2.       throws Exception {  
  3.      
  4.    Distribution newDistribution,secondDistribution;  
  5.    int numAttValues;  
  6.    double currIG,currGR;  
  7.    Instance instance;  
  8.    int i;  
  9.   
  10.    numAttValues = trainInstances.attribute(m_attIndex).numValues();  
  11.    newDistribution = new Distribution(numAttValues,  
  12.                    trainInstances.numClasses());  
  13.      
  14.    // Only Instances with known values are relevant.  
  15.    Enumeration enu = trainInstances.enumerateInstances();  
  16.    while (enu.hasMoreElements()) {  
  17.      instance = (Instance) enu.nextElement();  
  18.      if (!instance.isMissing(m_attIndex))  
  19. newDistribution.add((int)instance.value(m_attIndex),instance);  
  20.    }  
  21.    m_distribution = newDistribution;  
  22.   
  23.    // For all values  
  24.    for (i = 0; i < numAttValues; i++){  
  25.   
  26.      if (Utils.grOrEq(newDistribution.perBag(i),m_minNoObj)){  
  27. secondDistribution = new Distribution(newDistribution,i);  
  28.   
  29. // Check if minimum number of Instances in the two  
  30. // subsets.  
  31. if (secondDistribution.check(m_minNoObj)){  
  32.   m_numSubsets = 2;  
  33.   currIG = m_infoGainCrit.splitCritValue(secondDistribution,  
  34.                        m_sumOfWeights);  
  35.   currGR = m_gainRatioCrit.splitCritValue(secondDistribution,  
  36.                     m_sumOfWeights,  
  37.                     currIG);  
  38.   if ((i == 0) || Utils.gr(currGR,m_gainRatio)){  
  39.     m_gainRatio = currGR;  
  40.     m_infoGain = currIG;  
  41.     m_splitPoint = (double)i;  
  42.     m_distribution = secondDistribution;  
  43.   }  
  44. }  
  45.      }  
  46.    }  
可以看出,上一段代码根据该属性的不同的取值,在已有分布基础上,建立一个新的分布secondeDistribution,
[java]   view plain copy
  1. secondDistribution = new Distribution(newDistribution,i);  
该分布包含两列,属性下标为i的,其余的,在这个分布的基础上计算信息增益和信息增益率,并选出最优的。

换句话说,离散值分类的二元化处理就是选出其中一列当做一个branch,其余的当做另外一个branch。虽然从结构上来讲这肯定不是最优的选择,但简单易用就够了。


到这里基本分析完了J48的两个ModelSelection,下一篇文章将对classifierInstance过程进行分析,并给出一个简单的总结。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值