Weka开发[11]—J48源代码介绍

转自 Koala++'s blog  感谢原作者

 

这次介绍一下J48的源码,分析J48的源码似乎真还是有用的,同学改造J48写过VFDT,我自己用J48进行特征选择(当然很失败)。

J48buildClassfier函数:

publicvoid buildClassifier(Instances instances) throws Exception {

ModelSelection modSelection;

if (m_binarySplits)

modSelection = new BinC45ModelSelection(m_minNumObj, instances);

else

modSelection = new C45ModelSelection(m_minNumObj, instances);

if (!m_reducedErrorPruning)

m_root = new C45PruneableClassifierTree(modSelection,

!m_unpruned, m_CF, m_subtreeRaising, !m_noCleanup);

else

m_root = new PruneableClassifierTree(modSelection, !m_unpruned,

m_numFolds, !m_noCleanup, m_Seed);

m_root.buildClassifier(instances);

if (m_binarySplits) {

((BinC45ModelSelection) modSelection).cleanup();

} else {

((C45ModelSelection) modSelection).cleanup();

}

}

NBTree中已经介绍过了,ModelSelection是决定决策树的模型类,前面两个if,一个是判断连续属性是否只分出两个子结点,另一个判断是否最后剪枝。m_root是一个ClassifierTree对象,它调用buildClassifier函数。这里列出这个函数:

publicvoid buildClassifier(Instances data) throws Exception {

// can classifier tree handle the data?

getCapabilities().testWithFail(data);

// remove instances with missing class

data = new Instances(data);

data.deleteWithMissingClass();

buildTree(data, false);

}

有注释也没什么好说的,直接看最后一个函数buildTree

publicvoid buildTree(Instances data, boolean keepData) throws Exception {

Instances[] localInstances;

if (keepData) {

m_train = data;

}

m_test = null;

m_isLeaf = false;

m_isEmpty = false;

m_sons = null;

m_localModel = m_toSelectModel.selectModel(data);

if (m_localModel.numSubsets() > 1) {

localInstances = m_localModel.split(data);

data = null;

m_sons = new ClassifierTree[m_localModel.numSubsets()];

for (int i = 0; i < m_sons.length; i++) {

m_sons[i] = getNewTree(localInstances[i]);

localInstances[i] = null;

}

} else {

m_isLeaf = true;

if (Utils.eq(data.sumOfWeights(), 0))

m_isEmpty = true;

data = null;

}

}

这里的selectModel函数,如果看过NBTree一篇的读者应该不会太陌生,selectModel简单地说就是如果不符合分裂的条件就返回NoSplit,如果符合分裂的条件,则从currentModel数组中选出bestModel返回。

这最要注意的是selectModel也不只是决定哪个属性分裂,其实到底如何分裂已经在这个函数里算里出来了。

我把selectModel拆开来讲解

// Check if all Instances belong to one class or if not

// enough Instances to split.

checkDistribution = new Distribution(data);

noSplitModel = new NoSplit(checkDistribution);

if (Utils.sm(checkDistribution.total(), 2 * m_minNoObj)

|| Utils.eq(checkDistribution.total(), checkDistribution

.perClass(checkDistribution.maxClass())))

return noSplitModel;

2 * m_minNoObj表示至有有这么多样本才可以分裂,原因很简单,因为一个结点至少分出两个子结点,每个子结点至少有m_minNoObj个样本,第二个是或条件是表示是否这个结点上所有的样本都属于同一类别,也就是这个结点总的权重是否等于这个最多类别的权重。

// Check if all attributes are nominal and have a lot of values.

if (m_allData != null) {

Enumeration enu = data.enumerateAttributes();

while (enu.hasMoreElements()) {

attribute = (Attribute) enu.nextElement();

if ((attribute.isNumeric())

|| (Utils.sm((double) attribute.numValues(),

(0.3 * (double) m_allData.numInstances())))) {

multiVal = false;

break;

    }

}

}

判断是否有很多不同的属性值,标准就是如果有一个属性的属性值小多于总样本数*0.3,那么就是不是multiVal

currentModel = new C45Split[data.numAttributes()];

sumOfWeights = data.sumOfWeights();

// For each attribute.

for (i = 0; i < data.numAttributes(); i++) {

// Apart from class attribute.

if (i != (data).classIndex()) {

// Get models for current attribute.

currentModel[i] = new C45Split(i, m_minNoObj, sumOfWeights);

currentModel[i].buildClassifier(data);

// Check if useful split for current attribute

// exists and check for enumerated attributes with

// a lot of values.

if (currentModel[i].checkModel())

if (m_allData != null) {

if ((data.attribute(i).isNumeric())

|| (multiVal || Utils.sm((double) data

.attribute(i).numValues(),

(0.3 * (double) m_allData.numInstances())))) {

averageInfoGain = averageInfoGain

+ currentModel[i].infoGain();

validModels++;

}

} else {

averageInfoGain = averageInfoGain

+ currentModel[i].infoGain();

validModels++;

    }

    } else

currentModel[i] = null;

}

里面重要的两句就是:

// Get models for current attribute.

currentModel[i] = new C45Split(i, m_minNoObj, sumOfWeights);

currentModel[i].buildClassifier(data);

其它的也没有什么,求一下averageInfoGainvalidModelscheckModel如果可以分出子结点则为真。

这里是C45Split类的成员函数buildClassfier被调用,列出它的代码:

publicvoid buildClassifier(Instances trainInstances) throws Exception {

// Initialize the remaining instance variables.

m_numSubsets = 0;

m_splitPoint = Double.MAX_VALUE;

m_infoGain = 0;

m_gainRatio = 0;

// Different treatment for enumerated and numeric

// attributes.

if (trainInstances.attribute(m_attIndex).isNominal()) {

    m_complexityIndex = trainInstances.attribute(m_attIndex)

.numValues();

m_index = m_complexityIndex;

handleEnumeratedAttribute(trainInstances);

}else{

m_complexityIndex = 2;

m_index = 0;

trainInstances.sort(trainInstances.attribute(m_attIndex));

handleNumericAttribute(trainInstances);

}

}  

这里handleEnumerateAttributehandleNumericAttribute是决定到底是哪一个属性分裂(m_attIndex)和分裂出几个子结点的函数(m_numSubsets)。这里的m_comlexity就是指分可以分裂出多少子结点。如果是连续属性就是2。再看一下handleEnumeratedAttribute函数:

privatevoid handleEnumeratedAttribute(Instances trainInstances)

throws Exception {

Instance instance;

m_distribution = new Distribution(m_complexityIndex,

trainInstances.numClasses());

// Only Instances with known values are relevant.

Enumeration enu = trainInstances.enumerateInstances();

while (enu.hasMoreElements()) {

    instance = (Instance) enu.nextElement();

    if (!instance.isMissing(m_attIndex))

m_distribution.add((int) instance.value(m_attIndex),

instance);

}

// Check if minimum number of Instances in at least two

// subsets.

if (m_distribution.check(m_minNoObj)) {

m_numSubsets = m_complexityIndex;

m_infoGain = infoGainCrit.splitCritValue(m_distribution,

m_sumOfWeights);

m_gainRatio = gainRatioCrit.splitCritValue(m_distribution,

m_sumOfWeights, m_infoGain);

}

}

// Current attribute is a numeric attribute.

m_distribution = new Distribution(2, trainInstances.numClasses());

// Only Instances with known values are relevant.

Enumeration enu = trainInstances.enumerateInstances();

i = 0;

while (enu.hasMoreElements()) {

instance = (Instance) enu.nextElement();

if (instance.isMissing(m_attIndex))

break;

m_distribution.add(1, instance);

i++;

}

firstMiss = i;

已经讲过了,如果是连续属性就分出两个子结点,也就是Distribution的第一个参数。枚举所有样本,因为在调用HandleNumericAttribute之间已经对数据集根据m_attIndex排序过,所以缺失数据都在最后。也就是firstMiss是在m_attIndex上有确定值的样本个数+1。在while循环中,把所有的样本都先放到bag 1(add(1,instance))。还是列出来一下吧。

publicfinalvoid add(int bagIndex, Instance instance) throws Exception {

int classIndex;

double weight;

classIndex = (int) instance.classValue();

weight = instance.weight();

m_perClassPerBag[bagIndex][classIndex] =

m_perClassPerBag[bagIndex][classIndex] + weight;

m_perBag[bagIndex] = m_perBag[bagIndex] + weight;

m_perClass[classIndex] = m_perClass[classIndex] + weight;

totaL = totaL + weight;

}

也就这个函数也就是根据参数bagIndex和样本的类别值classIndex,三个成员变量m_perBag, m_perClass, m_perClassPerBag分别加上样本的权重。

// Compute minimum number of Instances required in each subset.

minSplit = 0.1 * (m_distribution.total())

/ ((double) trainInstances.numClasses());

if (Utils.smOrEq(minSplit, m_minNoObj))

minSplit = m_minNoObj;

elseif (Utils.gr(minSplit, 25))

minSplit = 25;

// Enough Instances with known values?

if (Utils.sm((double) firstMiss, 2 * minSplit))

return;

计算分最小分裂需要的样本数,这些涉及的值在Quinlan的论文中没有提到,可能也没有太多的道理,就是如果样本数的1/10小于m_minNoObj那么最小分裂样本数就是m_minNoObj,如果大于25,最小分裂样本数就是25

如果firstMiss小于2*minSplit表示已经不可以再分裂了(为什么刚才已经讲过了)。

// Compute values of criteria for all possible split indices.

defaultEnt = infoGainCrit.oldEnt(m_distribution);

while (next < firstMiss) {

if (trainInstances.instance(next - 1).value(m_attIndex)

+ 1e-5 < trainInstances.instance(next).value(m_attIndex)) {

// Move class values for all Instances up to next

// possible split point.

m_distribution.shiftRange(1, 0, trainInstances, last, next);

    // Check if enough Instances in each subset and compute

// values for criteria.

if (Utils.grOrEq(m_distribution.perBag(0), minSplit)

&& Utils.grOrEq(m_distribution.perBag(1), minSplit)) {

currentInfoGain = infoGainCrit.splitCritValue(

m_distribution, m_sumOfWeights, defaultEnt);

if (Utils.gr(currentInfoGain, m_infoGain)) {

m_infoGain = currentInfoGain;

splitIndex = next - 1;

}

m_index++;

}

last = next;

}

next++;

}

       oldEnt计算没有分裂的信息增益,得到defaultEnt注意,刚才是把样本放在了一个bag中。然后对所有有确定值的样本进行循环。第一个if,如果两个属性值太接近,那么选择的分裂点不会有太大的区别,就不进行处理。shiftRange是把第一个bag中下标从lastnext-1的样本移到第0bagshiftRange代码如下:

publicfinalvoid shiftRange(int from, int to, Instances source,

int startIndex, int lastPlusOne) throws Exception {

int classIndex;

double weight;

Instance instance;

int i;

for (i = startIndex; i < lastPlusOne; i++) {

instance = (Instance) source.instance(i);

classIndex = (int) instance.classValue();

weight = instance.weight();

m_perClassPerBag[from][classIndex] -= weight;

m_perClassPerBag[to][classIndex] += weight;

m_perBag[from] -= weight;

m_perBag[to] += weight;

}

}

很简单就是把对应样本的样本权重从from bag中减去,再加到to bag中。

转回来,如果bag 1bag 0都满足最小分裂样本数,计算在当前分裂点上的信息增益值。如果比上一个最好的分裂点的信息增益高,那么记录下当前的信息增益值为最高信息增益值m_infoGain,和当前分裂点splitIndex

// Was there any useful split?

if (m_index == 0)

return;

// Compute modified information gain for best split.

m_infoGain = m_infoGain - (Utils.log2(m_index) / m_sumOfWeights);

if (Utils.smOrEq(m_infoGain, 0))

return;

// Set instance variables' values to values for best split.

m_numSubsets = 2;

m_splitPoint = (trainInstances.instance(splitIndex + 1).value(

m_attIndex) + trainInstances.instance(splitIndex).value(

    m_attIndex)) / 2;

如果没有找到任何分裂点,返回,接下来的m_infoGain自己到J.R.QuinlanImproved use of continuous Attributes in C4.5论文中的第4页第二段中找。最后设置有两个结点,分裂点在刚才找到的最好的分裂点与下一个属性值的中点。

// In case we have a numerical precision problem we need to choose the

// smaller value

if (m_splitPoint == trainInstances.instance(splitIndex + 1).value(

m_attIndex)) {

m_splitPoint = trainInstances.instance(splitIndex).value(

m_attIndex);

}

// Restore distributioN for best split.

m_distribution = new Distribution(2, trainInstances.numClasses());

m_distribution.addRange(0, trainInstances, 0, splitIndex + 1);

m_distribution.addRange(1, trainInstances, splitIndex + 1, firstMiss);

// Compute modified gain ratio for best split.

m_gainRatio = gainRatioCrit.splitCritValue(m_distribution,

m_sumOfWeights, m_infoGain);

       if是处理精度的细节问题。然后重新通过addRange计算m_distribution,最后计算增益率(Gain Ratio)

这里看到又有一个新类Distribution类,还是要把Distribution类讲一下,Distribution类中有一个bag成员变量,它的意思是能有几个子结点。从下面的构造函数看出来的,第一个参数在上面调用它的时候用的就是m_complexityIndex.

public Distribution(int numBags, int numClasses) {

int i;

m_perClassPerBag = newdouble[numBags][0];

m_perBag = newdouble[numBags];

m_perClass = newdouble[numClasses];

for (i = 0; i < numBags; i++)

m_perClassPerBag[i] = newdouble[numClasses];

totaL = 0;

}

Distributionadd函数就是在相应的属性值上进行统计,太简单了,略过。

回到刚才的buildTree函数,如果numSubsets返回1,则表示当前结点不再分裂为叶子结点,如果大于1,那么调用split函数,split函数只是根据有上次得到的子结点数,并根据WhichSubset返回值,把当前结点的样本分到几个子结点去。再对每一个子结点训练一个新子树,到这已经与以前讲的ID3有很大的相似了。

可能大家学习的时候都对理论很感兴趣,但看了半天也没看到,有点不解,其实也很好找,当然应该在handleEnumerateAttributehandleNumericAttribute中了,也就是InfoGainSplitCritGainRatioSplitCrit两个类。

分裂一个样本与NBTree相似,这里不再赘述

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值