Weka算法Classifier-trees-RandomTree源码分析
一、RandomTree算法
在网上搜了一下,并没有找到RandomTree的严格意义上的算法描述,因此我觉得RandomTree充其量只是一种构建树的思路,和普通决策树相比,RandomTree会随机的选择若干属性来进行构建而不是选取所有的属性。
Weka在实现上,对于随机属性的选取、生成分裂点的过程是这样的:
1、设置一个要选取的属性的数量K
2、在全域属性中无放回的对属性进行抽样
3、算出该属性的信息增益(注意不是信息增益率)
4、重复K次,选出信息增益最大的当分裂节点。
5、构建该节点的孩子子树。
二、具体代码实现
(1)buildClassifier
- public void buildClassifier(Instances data) throws Exception {
- // 如果传入的K不合理,把K放到一个合理的范围里
- if (m_KValue > data.numAttributes() - 1)
- m_KValue = data.numAttributes() - 1;
- if (m_KValue < 1)
- m_KValue = (int) Utils.log2(data.numAttributes()) + 1;//这个是K的默认值
- // 判断一下该分类器是否有能力处理这个数据集,如果没能力直接就在testWithFail里抛异常退出了
- getCapabilities().testWithFail(data);
- // 删除掉missClass
- data = new Instances(data);
- data.deleteWithMissingClass();
- // 如果只有一列,就build一个ZeroR模型,之后就结束了。ZeroR模型分类是这样的:如果是连续型,总是返回期望,如果离散型,总是返回训练集中出现最多的那个
- if (data.numAttributes() == 1) {
- System.err
- .println("Cannot build model (only class attribute present in data!), "
- + "using ZeroR model instead!");
- m_zeroR = new weka.classifiers.rules.ZeroR();
- m_zeroR.buildClassifier(data);
- return;
- } else {
- m_zeroR = null;
- }
- // 如果m_NumFlods大于0,则会把数据集分为两部分,一部分用于train,一部分用于test,也就是backfit
- //分的方式和多折交叉验证是一样的,例如m_NumFlods是10的话,则train占90%,backfit占10%
- Instances train = null;
- Instances backfit = null;
- Random rand = data.getRandomNumberGenerator(m_randomSeed);
- if (m_NumFolds <= 0) {
- train = data;
- } else {
- data.randomize(rand);
- data.stratify(m_NumFolds);
- train = data.trainCV(m_NumFolds, 1, rand);
- backfit = data.testCV(m_NumFolds, 1);
- }
- // 生成所有的可选属性
- int[] attIndicesWindow = new int[data.numAttributes() - 1];
- int j = 0;
- for (int i = 0; i < attIndicesWindow.length; i++) {
- if (j == data.classIndex())
- j++; // 忽略掉classIndex
- attIndicesWindow[i] = j++;//这段代码有点奇怪,i和j是相等的,为啥不用attIndicesWindow=i?
- }
- // 算出每个class的频率,也就是每个分类出现的次数(更正确的说法应该是权重,但权重默认都是1)
- double[] classProbs = new double[train.numClasses()];
- for (int i = 0; i < train.numInstances(); i++) {
- Instance inst = train.instance(i);
- classProbs[(int) inst.classValue()] += inst.weight();
- }
- // Build tree
- m_Tree = new Tree();
- m_Info = new Instances(data, 0);
- m_Tree.buildTree(train, classProbs, attIndicesWindow, rand, 0);//调用tree的build方法,在后面单独分析
- // Backfit if required
- if (backfit != null) {
- m_Tree.backfitData(backfit);//在后面单独分析
- }
- }
这个Tree对象是RandomTree的一个子类,之前我还以为会复用其余的决策树模型(比如J48),但weka没这么做,很惊奇的是RandomTree和J48的作者还是同一个,不知道为啥这么设计。
(2)tree.buildTree
- protected void buildTree(Instances data, double[] classProbs,
- int[] attIndicesWindow, Random random, int depth) throws Exception {
- //首先判断一下是否有instance,如果没有的话直接就返回
- if (data.numInstances() == 0) {
- m_Attribute = -1;
- m_ClassDistribution = null;
- m_Prop = null;
- return;
- }
- m_ClassDistribution = classProbs.clone();
- if (Utils.sum(m_ClassDistribution) < 2 * m_MinNum
- || Utils.eq(m_ClassDistribution[Utils.maxIndex(m_ClassDistribution)],
- Utils.sum(m_ClassDistribution))
- || ((getMaxDepth() > 0) && (depth >= getMaxDepth()))) {
- // 递归结束的条件有3个 1、instance数量小于2*m_Minnum 2、instance都已经在同一个类中 3、达到最大的深度
- //前两个条件和j48的递归结束条件很相似,相关内容可参考我之前的几篇博客。
- m_Attribute = -1;
- m_Prop = null;
- return;
- }
- double val = -Double.MAX_VALUE;
- double split = -Double.MAX_VALUE;
- double[][] bestDists = null;
- double[] bestProps = null;
- int bestIndex = 0;
- double[][] props = new double[1][0];
- double[][][] dists = new double[1][0][0];//这个数组第一列只有下标为0的被用到,不知道为啥这么设计
- int attIndex = 0;//存储被选择到的属性
- int windowSize = attIndicesWindow.length;//存储目前可选择的属性的数量
- int k = m_KValue;//k代表还能选择的属性的数量
- boolean gainFound = false;//是否发现了一个有信息增益的节点
- while ((windowSize > 0) && (k-- > 0 || !gainFound)) {//此循环退出条件有2个 1、没有节点可以选了 2、已经选了k个属性了并且找到了一个有用的属性 换句话说,如果K次迭代没有找到可以分裂的随机节点,循环也会继续下去
- int chosenIndex = random.nextInt(windowSize);//随机选一个,生成下标
- attIndex = attIndicesWindow[chosenIndex];//得到该属性的index
- //下面三行把选择的属性放到attIndicesWindow的末尾,然后把windowSize-1这样下个循环就不会选到了,也就是实现了无放回的抽取
- attIndicesWindow[chosenIndex] = attIndicesWindow[windowSize - 1];
- attIndicesWindow[windowSize - 1] = attIndex;
- windowSize--;
- double currSplit = distribution(props, dists, attIndex, data);//这个函数计算了在使用attIndex进行分裂所产生的分布,如果classIndex是连续值的话,还计算了分裂点,原理和J48的split一样,不在赘述。
- double currVal = gain(dists[0], priorVal(dists[0]));//这个计算了信息增益
- if (Utils.gr(currVal, 0))
- gainFound = true;//如果信息增益大于0的话,说明节点有效,设置gainFound
- if ((currVal > val) || ((currVal == val) && (attIndex < bestIndex))) {
- val = currVal; //如果信息增益大的话,则更新把attIndex更新为bestIndex,这是为了选取最优的节点(ID3)的方法
- bestIndex = attIndex;
- split = currSplit;
- bestProps = props[0];
- bestDists = dists[0];
- }
- }
- m_Attribute = bestIndex;
- // Any useful split found?
- if (Utils.gr(val, 0)) {
- pan style="white-space:pre"> </span>//如果找到了一个分裂点,则在该分裂点的基础上构建子树
- m_SplitPoint = split;
- m_Prop = bestProps;
- Instances[] subsets = splitData(data);
- m_Successors = new Tree[bestDists.length];
- for (int i = 0; i < bestDists.length; i++) {
- m_Successors[i] = new Tree();
- m_Successors[i].buildTree(subsets[i], bestDists[i], attIndicesWindow,
- random, depth + 1);//注意这里传入的attIndicesWindow没有变,换句话说,每次迭代传入的可选属性集合是一样的,因此子节点在进行属性的random选择时,很有可能会选择到父节点已经选过的节点,但因为不产生信息增益,因此不会再次作为bestIndex,但会产生额外的计算量(我感觉还不少),这里还有一定的优化空间,同理j48也是这么实现的。
- }
- boolean emptySuccessor = false;
- for (int i = 0; i < subsets.length; i++) {
- if (m_Successors[i].m_ClassDistribution == null) {
- emptySuccessor = true;
- break;
- }
- }
- if (!emptySuccessor) {
- m_ClassDistribution = null;
- }
- } else {
- //这个else是<span style="font-family: Arial, Helvetica, sans-serif;">Utils.gr(currVal, 0)这个条件的,代表没有选择到合适的分裂节点</span>
- m_Attribute = -1;
- }
(3)tree.backfit
什么是Backfit?Backfit将改变已有tree节点及其子节点的class分布,而class分布将直接被用于实例的预测。
直接使用RandomTree有时会出现过拟合的现象(通过代码可以看到,和J48相比没有剪枝过程),因此通过传入一个新的数据集来backfit已有节点是一个解决过拟合的方法。
- protected void backfitData(Instances data, double[] classProbs)
- throws Exception {
- <span style="white-space:pre"> </span>//判断一下是否有数据
- if (data.numInstances() == 0) {
- m_Attribute = -1;
- m_ClassDistribution = null;
- m_Prop = null;
- return;
- }
- m_ClassDistribution = classProbs.clone();
- if (m_Attribute > -1) {
- // m_Attribut>-1代表不是leaf,可以看上面的buildTree得出这个结论
- m_Prop = new double[m_Successors.length];//子节点数组的length也就是分类的类的数量
- <span style="white-space:pre"> </span>//把传入的data用此节点算各类的频率
- for (int i = 0; i < data.numInstances(); i++) {
- Instance inst = data.instance(i);
- if (!inst.isMissing(m_Attribute)) {
- if (data.attribute(m_Attribute).isNominal()) {
- m_Prop[(int) inst.value(m_Attribute)] += inst.weight();
- } else {
- m_Prop[(inst.value(m_Attribute) < m_SplitPoint) ? 0 : 1] += inst
- .weight();//连续型只会分两类,小于splitPoint一类,大于是一类,和J48采用的策略相同
- }
- }
- }
- if (Utils.sum(m_Prop) <= 0) {
- m_Attribute = -1;//如果data全部都是missingValue,则把此节点变成leaf节点
- m_Prop = null;
- return;
- }
- // 归一化
- Utils.normalize(m_Prop);
- // 根据本节点算出在data上进行分类的subset
- Instances[] subsets = splitData(data);
- for (int i = 0; i < subsets.length; i++) {
- // 递归的对孩子节点进行backfit
- double[] dist = new double[data.numClasses()];
- for (int j = 0; j < subsets[i].numInstances(); j++) {
- dist[(int) subsets[i].instance(j).classValue()] += subsets[i]
- .instance(j).weight();
- }
- m_Successors[i].backfitData(subsets[i], dist);
- }
- <span style="white-space:pre"> </span>
- if (getAllowUnclassifiedInstances()) {
- m_ClassDistribution = null;
- return;
- }
- <span style="white-space:pre"> </span>//如果某个子节点的分布为空的话,则父节点要保存分布,否则不需要持有分布。
- <span style="white-space:pre"> </span>//为什么呢?因为使用RandomTree进行预测时会遍历节点的分布并进行累加,得到分布最大的class作为预测class,在J48的那篇博客中有分析
- boolean emptySuccessor = false;
- for (int i = 0; i < subsets.length; i++) {
- if (m_Successors[i].m_ClassDistribution == null) {
- emptySuccessor = true;
- return;
- }
- }
- m_ClassDistribution = null;
- }
- }
-
三、总结
对RandomForest的分析到这里就结束了,首先分析了RandomForest,接着分析了Bagging,最后分析了RandomTree。