ModelSelection主要是用于选择合适的列对数据集进行分割,结合上一篇J48的主流程,发现用到的ModelSelection有 C45ModelSelection以及BinC45ModelSelection,先来分析C45ModelSelection。
一、C45ModelSelection
首先作为一个ModelSelection接口,实现的主要方法有两个,分别是selectModel(Instances)和selectionModel(Instances,Instances)。C45ModelSelection的后一个方法如下:
public final ClassifierSplitModel selectModel(Instances train, Instances test) {
return selectModel(train);
}
可以看到就是忽略了test测试集直接调用selectModel方法而已,因此主要分词selectModel方法。
先放出整段代码,然后对该段代码进行分析:
public final ClassifierSplitModel selectModel(Instances data){
double minResult;
double currentResult;
C45Split [] currentModel;
C45Split bestModel = null;
NoSplit noSplitModel = null;
double averageInfoGain = 0;
int validModels = 0;
boolean multiVal = true;
Distribution checkDistribution;
Attribute attribute;
double sumOfWeights;
int i;
try{
// Check if all Instances belong to one class or if not
// enough Instances to split.
checkDistribution = new Distribution(data);
noSplitModel = new NoSplit(checkDistribution);
if (Utils.sm(checkDistribution.total(),2*m_minNoObj) ||
Utils.eq(checkDistribution.total(),
checkDistribution.perClass(checkDistribution.maxClass())))
return noSplitModel;
// Check if all attributes are nominal and have a
// lot of values.
if (m_allData != null) {
Enumeration enu = data.enumerateAttributes();
while (enu.hasMoreElements()) {
attribute = (Attribute) enu.nextElement();
if ((attribute.isNumeric()) ||
(Utils.sm((double)attribute.numValues(),
(0.3*(double)m_allData.numInstances())))){
multiVal = false;
break;
}
}
}
currentModel = new C45Split[data.numAttributes()];
sumOfWeights = data.sumOfWeights();
// For each attribute.
for (i = 0; i < data.numAttributes(); i++){
// Apart from class attribute.
if (i != (data).classIndex()){
// Get models for current attribute.
currentModel[i] = new C45Split(i,m_minNoObj,sumOfWeights);
currentModel[i].buildClassifier(data);
// Check if useful split for current attribute
// exists and check for enumerated attributes with
// a lot of values.
if (currentModel[i].checkModel())
if (m_allData != null) {
if ((data.attribute(i).isNumeric()) ||
(multiVal || Utils.sm((double)data.attribute(i).numValues(),
(0.3*(double)m_allData.numInstances())))){
averageInfoGain = averageInfoGain+currentModel[i].infoGain();
validModels++;
}
} else {
averageInfoGain = averageInfoGain+currentModel[i].infoGain();
validModels++;
}
}else
currentModel[i] = null;
}
// Check if any useful split was found.
if (validModels == 0)
return noSplitModel;
averageInfoGain = averageInfoGain/(double)validModels;
// Find "best" attribute to split on.
minResult = 0;
for (i=0;i<data.numAttributes();i++){
if ((i != (data).classIndex()) &&
(currentModel[i].checkModel()))
// Use 1E-3 here to get a closer approximation to the original
// implementation.
if ((currentModel[i].infoGain() >= (averageInfoGain-1E-3)) &&
Utils.gr(currentModel[i].gainRatio(),minResult)){
bestModel = currentModel[i];
minResult = currentModel[i].gainRatio();
}
}
// Check if useful split was found.
if (Utils.eq(minResult,0))
return noSplitModel;
// Add all Instances with unknown values for the corresponding
// attribute to the distribution for the model, so that
// the complete distribution is stored with the model.
bestModel.distribution().
addInstWithUnknown(data,bestModel.attIndex());
// Set the split point analogue to C45 if attribute numeric.
if (m_allData != null)
bestModel.setSplitPoint(m_allData);
return bestModel;
}catch(Exception e){
e.printStackTrace();
}
return null;
}
第一部分,主要是对局部变量的一些定义。
double minResult;//最小的信息增益率
double currentResult;//当前信息增益率
C45Split [] currentModel;//存放所有未分类属性产生的模型
C45Split bestModel = null;//目前为止的最好模型
NoSplit noSplitModel = null;//代表不用分的模型
double averageInfoGain = 0;//各模型(currentModel)的平均信息增益
int validModels = 0;//是否存在有效模型
boolean multiVal = true;//是否多值
Distribution checkDistribution;//训练数据集的分布
Attribute attribute;//属性列集合
double sumOfWeights;//训练数据集的weight的和
int i;//循环变量
第二部分,递归出口。
checkDistribution = new Distribution(data);
noSplitModel = new NoSplit(checkDistribution);
if (Utils.sm(checkDistribution.total(),2*m_minNoObj) ||
Utils.eq(checkDistribution.total(),
checkDistribution.perClass(checkDistribution.maxClass())))
return noSplitModel;
可以看到,如果当前数据集数量小于2*m_minNoObj(这个值默认是2),或者当前数据集已经全在同一个分类中,就返回noSplitModel代表不用分,这就是整个C45分类树节点停止分裂的条件。
第三部分,判断是否是多值:
if (m_allData != null) {
Enumeration enu = data.enumerateAttributes();
while (enu.hasMoreElements()) {
attribute = (Attribute) enu.nextElement();
if ((attribute.isNumeric()) ||
(Utils.sm((double)attribute.numValues(),
(0.3*(double)m_allData.numInstances())))){
multiVal = false;
break;
}
}
}
如果属性中,任意一列是数值型,或者其取值的数量小于训练集数量*0.3,则不是多值,否则按多值处理。是否是多值影响到后面某些逻辑。
第四部分,对于每一列属性构造Spliter。
for (i = 0; i < data.numAttributes(); i++){
// Apart from class attribute.
if (i != (data).classIndex()){
// Get models for current attribute.
currentModel[i] = new C45Split(i,m_minNoObj,sumOfWeights);
currentModel[i].buildClassifier(data);
// Check if useful split for current attribute
// exists and check for enumerated attributes with
// a lot of values.
if (currentModel[i].checkModel())
if (m_allData != null) {
if ((data.attribute(i).isNumeric()) ||
(multiVal || Utils.sm((double)data.attribute(i).numValues(),
(0.3*(double)m_allData.numInstances())))){
averageInfoGain = averageInfoGain+currentModel[i].infoGain();
validModels++;
}
} else {
averageInfoGain = averageInfoGain+currentModel[i].infoGain();
validModels++;
}
}else
currentModel[i] = null;
}
对于每一列属性,如果不是存放分类的值得话,则构造C45Split对象,在该对象上进行分类,然后算出信息增益,相加到averageInfoGain上。对于C45Split的构造,稍后再看。
第五部分,选出最优模型。
if (validModels == 0)
return noSplitModel;
averageInfoGain = averageInfoGain/(double)validModels;
// Find "best" attribute to split on.
minResult = 0;
for (i=0;i<data.numAttributes();i++){
if ((i != (data).classIndex()) &&
(currentModel[i].checkModel()))
// Use 1E-3 here to get a closer approximation to the original
// implementation.
if ((currentModel[i].infoGain() >= (averageInfoGain-1E-3)) &&
Utils.gr(currentModel[i].gainRatio(),minResult)){
bestModel = currentModel[i];
minResult = currentModel[i].gainRatio();
}
如果存在有效模型,则选出有效模型。注意这个选出最优模型的逻辑,并不是单纯的选出gainRatio最大的,而是在基础上必须还要大于平均信息增益,这也是和传统的c45算法不一样的一点。
从上述过程来看,Weka在实现C45的时候做了一个小的变动,并没有从“还没有使用的”属性列中找出最合理的列最为分割属性,而是在“所有的列”中找出最合理的列作为分割属性,虽然这二者在结果上肯定是等价的(之前是有过的属性不和能有很好的信息增益率),但效率上个人对Weka的做法持保留意见。
二、C45Spliter
在ModelSelection中真正根据属性对训练集进行分割、计算信息增益和信息增益率的是C45Spliter,首先也从其buildClassifier方法入手进行分析。
public void buildClassifier(Instances trainInstances)
throws Exception {
// Initialize the remaining instance variables.
m_numSubsets = 0;
m_splitPoint = Double.MAX_VALUE;
m_infoGain = 0;
m_gainRatio = 0;
// Different treatment for enumerated and numeric
// attributes.
if (trainInstances.attribute(m_attIndex).isNominal()) {
m_complexityIndex = trainInstances.attribute(m_attIndex).numValues();
m_index = m_complexityIndex;
handleEnumeratedAttribute(trainInstances);
}else{
m_complexityIndex = 2;
m_index = 0;
trainInstances.sort(trainInstances.attribute(m_attIndex));
handleNumericAttribute(trainInstances);
}
}
可以看到,对于枚举型和数值型的属性是分开处理的,枚举型调用handlEnumeratedAttribute,数值型调用handleNumericAttribute,值得注意的是,在处理数值型之前,按照相应列进行排序,同时设置m_complexityIndex也就是期望分裂的节点数设定为2。
首先来看枚举类型是如何处理的。
private void handleEnumeratedAttribute(Instances trainInstances)
throws Exception {
Instance instance;
m_distribution = new Distribution(m_complexityIndex,
trainInstances.numClasses());
// Only Instances with known values are relevant.
Enumeration enu = trainInstances.enumerateInstances();
while (enu.hasMoreElements()) {
instance = (Instance) enu.nextElement();
if (!instance.isMissing(m_attIndex))
m_distribution.add((int)instance.value(m_attIndex),instance);
}
// Check if minimum number of Instances in at least two
// subsets.
if (m_distribution.check(m_minNoObj)) {
m_numSubsets = m_complexityIndex;
m_infoGain = infoGainCrit.
splitCritValue(m_distribution,m_sumOfWeights);
m_gainRatio =
gainRatioCrit.splitCritValue(m_distribution,m_sumOfWeights,
m_infoGain);
}
}
大概流程是新建一个分布,遍历所有instance,如果该instance对应的分裂的属性不为空的话,则放到不同的bag里,之后检查一下这个分布是否满足要求,要求就是最多允许有一个bag里的数据数量小于m_minNoObj,如果通过检查,就设置subset的数量,计算信息增益和信息增益率,否则subset默认会是0,上层调用checkModel就会返回false代表这是一个无效模型。
接下来看数值型是如何处理的:
private void handleNumericAttribute(Instances trainInstances)
throws Exception {
int firstMiss;//最后一个有效instance的下标
int next = 1;//下一个instance的index
int last = 0;//当前instance的index
int splitIndex = -1;//分裂点
double currentInfoGain;//当前信息增益
double defaultEnt;//分割之前的信息熵
double minSplit;
Instance instance;
int i;
//首先新建一个分布,数值型默认处理为2维分布,也就可以理解为小于某个值放到一个Bag里,其余的放到另外一个Bag里
m_distribution = new Distribution(2,trainInstances.numClasses());
Enumeration enu = trainInstances.enumerateInstances();
i = 0;
<pre name="code" class="cpp">//注意instances传入的时候是排好序的,这个排序保证了missingValue放在最后面,所以读到了missingValue其之后肯定都是miss//ingValue,换言之,firstMiss在循环之后代表了最后一个有效的instance的下标。
while (enu.hasMoreElements()) { instance = (Instance) enu.nextElement(); if (instance.isMissing(m_attIndex))break; m_distribution.add(1,instance); i++; } firstMiss = i;//循环结束后,m_distribution里放入了所有的有效instance,并全放入了bag1里。
//minSplit是最后分类好每个Bag里最小的数据的量,也就是0.1*每个类的均值。
minSplit = 0.1*(m_distribution.total())/
((double)trainInstances.numClasses());
if (Utils.smOrEq(minSplit,m_minNoObj))
minSplit = m_minNoObj;
else
if (Utils.gr(minSplit,25))
minSplit = 25;
//如果有效数据总量不到2*minSplit,换言之无论怎么分均不能保证2个bag里的数量大于minSplit,就直接返回。
if (Utils.sm((double)firstMiss,2*minSplit))
return;
//defaultEnt代表旧的信息熵,也就是对该属性进行分类之前,Indexclass对应的信息熵。
defaultEnt = infoGainCrit.oldEnt(m_distribution);
while (next < firstMiss) {
if (trainInstances.instance(next-1).value(m_attIndex)+1e-5 <
trainInstances.instance(next).value(m_attIndex)) {
<pre name="code" class="cpp">//Instances里的记录是升序排列的,加上这个条件默认把值相差很小的Instance就当做同一个instance处理了
//last代表当前,next代表下一个,默认next=1,last=0,所以shiftRange可以理解成把当前记录从bag1移动到bag0中
<span style="font-family: Arial, Helvetica, sans-serif;">//注意一开始初始化时候所有的都是在bag1里面的。 </span>
m_distribution.shiftRange(1,0,trainInstances,last,next);if (Utils.grOrEq(m_distribution.perBag(0),minSplit) && //如果两个bag都满足最小数据集的数量minSplit Utils.grOrEq(m_distribution.perBag(1),minSplit)) { currentInfoGain = infoGainCrit. splitCritValue(m_distribution,m_sumOfWeights, //算一下信息增益 defaultEnt);
if (Utils.gr(currentInfoGain,m_infoGain)) {
m_infoGain = currentInfoGain;//如果信息增益比当前最大的要大,则替换当前最大的值,并记录splitIndex
splitIndex = next-1;
}
m_index++;
}
last = next;
}
next++;
}
if (m_index == 0)
return; //执行到这里说明没找到一个合适的分裂点,直接返回。
// 计算最佳信息增益
m_infoGain = m_infoGain-(Utils.log2(m_index)/m_sumOfWeights);
if (Utils.smOrEq(m_infoGain,0))
return; //如果信息增益是0也说明没找到合适的分裂点,直接返回。
//剩下的就是根据分裂点进行属性的划分。
m_numSubsets = 2;
m_splitPoint =
(trainInstances.instance(splitIndex+1).value(m_attIndex)+
trainInstances.instance(splitIndex).value(m_attIndex))/2;
// In case we have a numerical precision problem we need to choose the
// smaller value
if (m_splitPoint == trainInstances.instance(splitIndex + 1).value(m_attIndex)) {
m_splitPoint = trainInstances.instance(splitIndex).value(m_attIndex);
}
// Restore distributioN for best split.
m_distribution = new Distribution(2,trainInstances.numClasses());
m_distribution.addRange(0,trainInstances,0,splitIndex+1);
m_distribution.addRange(1,trainInstances,splitIndex+1,firstMiss);
// Compute modified gain ratio for best split.
m_gainRatio = gainRatioCrit.
splitCritValue(m_distribution,m_sumOfWeights,
m_infoGain);
}
这个函数有点复杂,具体逻辑也写到代码注释里了。
三、BinC45ModelSelection
该函数只负责生成二元分类树的模型,selectModel方法和C45ModelSelection几乎一样,不在多说,不同点在于其使用BinC45Spliter而不是C45Spliter。
四、BinC45Spliter
handleNumericAttribute对于数值类型的属性处理和C45Spliter完全一样。下面只分析一下handleEnumeratedAttribute。
private void handleEnumeratedAttribute(Instances trainInstances)
throws Exception {
Distribution newDistribution,secondDistribution;
int numAttValues;
double currIG,currGR;
Instance instance;
int i;
numAttValues = trainInstances.attribute(m_attIndex).numValues();
newDistribution = new Distribution(numAttValues,
trainInstances.numClasses());
// Only Instances with known values are relevant.
Enumeration enu = trainInstances.enumerateInstances();
while (enu.hasMoreElements()) {
instance = (Instance) enu.nextElement();
if (!instance.isMissing(m_attIndex))
newDistribution.add((int)instance.value(m_attIndex),instance);
}
m_distribution = newDistribution;
// For all values
for (i = 0; i < numAttValues; i++){
if (Utils.grOrEq(newDistribution.perBag(i),m_minNoObj)){
secondDistribution = new Distribution(newDistribution,i);
// Check if minimum number of Instances in the two
// subsets.
if (secondDistribution.check(m_minNoObj)){
m_numSubsets = 2;
currIG = m_infoGainCrit.splitCritValue(secondDistribution,
m_sumOfWeights);
currGR = m_gainRatioCrit.splitCritValue(secondDistribution,
m_sumOfWeights,
currIG);
if ((i == 0) || Utils.gr(currGR,m_gainRatio)){
m_gainRatio = currGR;
m_infoGain = currIG;
m_splitPoint = (double)i;
m_distribution = secondDistribution;
}
}
}
}
可以看出,上一段代码根据该属性的不同的取值,在已有分布基础上,建立一个新的分布secondeDistribution,
secondDistribution = new Distribution(newDistribution,i);
该分布包含两列,属性下标为i的,其余的,在这个分布的基础上计算信息增益和信息增益率,并选出最优的。
换句话说,离散值分类的二元化处理就是选出其中一列当做一个branch,其余的当做另外一个branch。虽然从结构上来讲这肯定不是最优的选择,但简单易用就够了。
到这里基本分析完了J48的两个ModelSelection,下一篇文章将对classifierInstance过程进行分析,并给出一个简单的总结。