An ADTree has two main kinds of nodes: PredictionNode and SplitNode. The Weka implementation defines a data structure for each of them.
public class PredictionNode
{
    double value;
    FastVector children;
}
value stores a or b (see the paper for their exact meaning); children stores the SplitNodes attached below this node.
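To make a and b concrete: a prediction value is half the log ratio of the positive weight to the negative weight reaching the node. The sketch below assumes the +1-smoothed form of that formula and works on plain weight arrays; the class and method names are illustrative only and are not Weka's API (Weka's own helper is the calcPredictionValue() call that appears in boost()).

```java
/**
 * Illustrative sketch (not Weka's code) of the value stored in a PredictionNode.
 * Assumes the +1-smoothed form: value = 0.5 * ln((W+ + 1) / (W- + 1)).
 */
class PredictionValueSketch {

    /** Sum of weights; a stand-in for Instances.sumOfWeights(). */
    static double sumOfWeights(double[] weights) {
        double sum = 0.0;
        for (double w : weights) sum += w;
        return sum;
    }

    /**
     * W+ / W- are the total weights of the positive / negative instances reaching the
     * node; the +1 keeps the value finite when one side is empty.
     */
    static double predictionValue(double[] posWeights, double[] negWeights) {
        return 0.5 * Math.log((sumOfWeights(posWeights) + 1.0)
                              / (sumOfWeights(negWeights) + 1.0));
    }
}
```

For example, three positive instances of weight 1 and no negatives give 0.5 * ln(4) = ln(2), roughly 0.693, a vote towards the positive class; equal weights on both sides give 0.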
public abstract class Splitter
{
    public int orderAdded;
    /*
     * There are some other abstract methods as well; they are omitted here.
     */
}
The author of the Weka code implemented two subclasses of Splitter: TwoWayNominalSplit and TwoWayNumericSplit. We will look at their details when they are used.
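As a rough picture of what a two-way nominal split does, here is a hypothetical, simplified stand-in (not Weka's class): branch 0 receives the instances whose attribute equals one chosen value, branch 1 receives all others. Instances are modelled as int arrays of nominal value indices instead of Weka's Instance, and missing values are ignored; the method names only mirror the calls (getNumOfBranches, instancesDownBranch) that appear in the code later in this post.

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical two-way nominal split: "attribute == value" vs. "attribute != value". */
class TwoWayNominalSplitSketch {
    final int attIndex;        // attribute that is tested
    final int trueSplitValue;  // value index that sends an instance down branch 0

    TwoWayNominalSplitSketch(int attIndex, int trueSplitValue) {
        this.attIndex = attIndex;
        this.trueSplitValue = trueSplitValue;
    }

    int getNumOfBranches() {
        return 2;
    }

    /** 0 if the instance has the chosen value, 1 otherwise (missing values not modelled). */
    int branchInstanceGoesDown(int[] instance) {
        return instance[attIndex] == trueSplitValue ? 0 : 1;
    }

    /** Keeps only the instances that go down the given branch. */
    List<int[]> instancesDownBranch(int branch, List<int[]> instances) {
        List<int[]> result = new ArrayList<>();
        for (int[] inst : instances) {
            if (branchInstanceGoesDown(inst) == branch) result.add(inst);
        }
        return result;
    }
}
```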
The main class is, naturally, ADTree:
public class ADTree
{
    protected Instances m_trainInstances;
    /** The root of the tree */
    protected PredictionNode m_root = null;
    /** The number of the last splitter added to the tree */
    protected int m_lastAddedSplitNum = 0;
    /** An array containing the indices to the numeric attributes in the data */
    protected int[] m_numericAttIndices;
    /** An array containing the indices to the nominal attributes in the data */
    protected int[] m_nominalAttIndices;
    /** The total weight of the instances - used to speed Z calculations */
    protected double m_trainTotalWeight;
    /** The training instances with positive class - referencing the training dataset */
    protected ReferenceInstances m_posTrainInstances;
    /** The training instances with negative class - referencing the training dataset */
    protected ReferenceInstances m_negTrainInstances;
    /** The best node to insert under, as found so far by the latest search */
    protected PredictionNode m_search_bestInsertionNode;
    /** The best splitter to insert, as found so far by the latest search */
    protected Splitter m_search_bestSplitter;
    /** The smallest Z value found so far by the latest search */
    protected double m_search_smallestZ;
    /** The positive instances that apply to the best path found so far */
    protected Instances m_search_bestPathPosInstances;
    /** The negative instances that apply to the best path found so far */
    protected Instances m_search_bestPathNegInstances;
}
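These are the most important member variables. The author documented each of them, so readers can get familiar with them now; I will come back to them when analyzing the methods.
Below is the entry point, buildClassifier():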
public void buildClassifier(Instances instances) throws Exception {
    /*
     * Initializes m_trainInstances, m_posTrainInstances, m_negTrainInstances, m_root,
     * m_numericAttIndices, m_nominalAttIndices and m_trainTotalWeight.
     */
    initClassifier(instances);
    // typical AdaBoost loop: one boosting round per iteration
    for (int T = 0; T < m_boostingIterations; T++) boost();
}
I won't reproduce initClassifier(instances) here; it does what the comment above describes, and readers can look it up in the source. Let's analyze the core method, boost():
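Since initClassifier() is not shown, here is a hypothetical sketch of two of the duties listed in its comment, assuming binary class labels (true = positive): it splits the training set into positive and negative index lists that both refer back to the same data, loosely like the two ReferenceInstances members, and sums the total weight. The names are made up for illustration and are not Weka's.

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical sketch: partition by class and record the total training weight. */
class InitSketch {
    final List<Integer> posIndices = new ArrayList<>(); // like m_posTrainInstances
    final List<Integer> negIndices = new ArrayList<>(); // like m_negTrainInstances
    final double trainTotalWeight;                      // like m_trainTotalWeight

    InitSketch(boolean[] isPositive, double[] weights) {
        double total = 0.0;
        for (int i = 0; i < isPositive.length; i++) {
            // store indices only, so both lists reference the original training data
            if (isPositive[i]) posIndices.add(i); else negIndices.add(i);
            total += weights[i];
        }
        trainTotalWeight = total;
    }
}
```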
public void boost() throws Exception {
    if (m_trainInstances == null || m_trainInstances.numInstances() == 0)
        throw new Exception("Trying to boost with no training data");
    /*
     * Entry point of the search: this is where the split node is actually chosen,
     * i.e. where m_search_bestSplitter and m_search_bestInsertionNode are assigned.
     */
    searchForBestTestSingle();
    if (m_search_bestSplitter == null) return; // handle empty instances
    /*
     * Create the two child PredictionNodes of m_search_bestSplitter.
     */
    for (int i = 0; i < 2; i++) {
        Instances posInstances =
            m_search_bestSplitter.instancesDownBranch(i, m_search_bestPathPosInstances);
        Instances negInstances =
            m_search_bestSplitter.instancesDownBranch(i, m_search_bestPathNegInstances);
        double predictionValue = calcPredictionValue(posInstances, negInstances);
        PredictionNode newPredictor = new PredictionNode(predictionValue);
        updateWeights(posInstances, negInstances, predictionValue);
        m_search_bestSplitter.setChildForBranch(i, newPredictor);
    }
    /*
     * Insert m_search_bestSplitter into the ADTree. This is where m_search_bestInsertionNode
     * matters: it remembers where m_search_bestSplitter must be inserted, so the insertion
     * point is not lost during the recursive search.
     */
    m_search_bestInsertionNode.addChild((Splitter) m_search_bestSplitter, this);
    // free memory
    m_search_bestPathPosInstances = null;
    m_search_bestPathNegInstances = null;
    m_search_bestSplitter = null;
}
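In short, boost() finds the split with the smallest Z value (which requires a recursive search over the tree) and turns it into a SplitNode with its two child PredictionNodes. Next, let's look at searchForBestTestSingle(). The no-argument overload called in boost() is not shown here; the recursive version below takes the current PredictionNode plus the positive and negative instances that reach it, and the search presumably starts from m_root.
Note: so far I have only studied how nominal (discrete) attributes are handled, not numeric (continuous) ones. I will add that once I have.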
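The updateWeights() call in boost() is the AdaBoost step. Its body is not shown in this post; the sketch below assumes the standard rule from the paper, w ← w · exp(−y · v), with y = +1 for positive and y = −1 for negative instances and v the new prediction value. Weights are plain arrays and the class name is illustrative.

```java
/**
 * Sketch of the AdaBoost-style weight update (assumed rule: w <- w * exp(-y * v)).
 * Only the instances that reach the new PredictionNode are passed in.
 */
class WeightUpdateSketch {
    static void updateWeights(double[] posWeights, double[] negWeights, double v) {
        double posFactor = Math.exp(-v); // y = +1
        double negFactor = Math.exp(v);  // y = -1
        for (int i = 0; i < posWeights.length; i++) posWeights[i] *= posFactor;
        for (int i = 0; i < negWeights.length; i++) negWeights[i] *= negFactor;
    }
}
```

With v = ln 2 (the node votes positive), a positive instance of weight 1 drops to 0.5 and a negative one rises to 2, so later rounds concentrate on the instances this node gets wrong.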
private void searchForBestTestSingle(PredictionNode currentNode,
                                     Instances posInstances, Instances negInstances)
    throws Exception {
    // don't investigate pure or empty nodes any further
    if (posInstances.numInstances() == 0 || negInstances.numInstances() == 0) return;
    // do z-pure cutoff
    /*
     * Judging from this check, calcZpure() acts as an optimistic (lower-bound) estimate
     * of the Z value reachable from this node: if even that cannot beat the best Z found
     * so far, this node and everything below it are skipped.
     * I could not find this formula in the paper...
     */
    if (calcZpure(posInstances, negInstances) >= m_search_smallestZ) return;
    /*
     * These two lines can be ignored; they only record statistics.
     */
    m_nodesExpanded++;
    m_examplesCounted += posInstances.numInstances() + negInstances.numInstances();
    // evaluate static splitters (nominal)
    /*
     * Compute the Z value for each nominal attribute.
     */
    for (int i = 0; i < m_nominalAttIndices.length; i++)
        evaluateNominalSplitSingle(m_nominalAttIndices[i], currentNode,
                                   posInstances, negInstances);
    // evaluate dynamic splitters (numeric)
    if (m_numericAttIndices.length > 0) {
        // merge the two sets of instances into one
        Instances allInstances = new Instances(posInstances);
        for (Enumeration e = negInstances.enumerateInstances(); e.hasMoreElements(); )
            allInstances.add((Instance) e.nextElement());
        // use method of finding the optimal Z split-point
        for (int i = 0; i < m_numericAttIndices.length; i++)
            evaluateNumericSplitSingle(m_numericAttIndices[i], currentNode,
                                       posInstances, negInstances, allInstances);
    }
    /*
     * Return point of the recursion: if this PredictionNode has no children, return.
     * Otherwise, go on to compute the Z values for the two branches (two PredictionNodes)
     * of each child SplitNode, and keep recursing while those have children.
     * This shows that Z values are compared across all branch points of the whole tree.
     * An upper level may give a smaller Z than a lower one, in which case another split
     * is added at the upper level; this is how a node ends up with more than two branches,
     * and it can also mean the same attribute is tested twice.
     */
    if (currentNode.getChildren().size() == 0) return;
    // keep searching
    switch (m_searchPath) {
    case SEARCHPATH_ALL:
        goDownAllPathsSingle(currentNode, posInstances, negInstances);
        break;
    case SEARCHPATH_HEAVIEST:
        goDownHeaviestPathSingle(currentNode, posInstances, negInstances);
        break;
    case SEARCHPATH_ZPURE:
        goDownZpurePathSingle(currentNode, posInstances, negInstances);
        break;
    case SEARCHPATH_RANDOM:
        goDownRandomPathSingle(currentNode, posInstances, negInstances);
        break;
    }
}
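According to the paper, each round computes a Z value for every combination of a precondition (an existing PredictionNode, where the new split would be attached) and a condition (a candidate split test). On the tree, this means evaluating every candidate SplitNode at every PredictionNode. Since this search is recursive, the smallest Z value found so far has to be kept globally, which is exactly what these members are for:
protected PredictionNode m_search_bestInsertionNode;
protected Splitter m_search_bestSplitter;
protected double m_search_smallestZ;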
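For reference, the quantity being minimized, as defined in the original ADTree paper (Freund and Mason): for a precondition $c_1$ (an existing PredictionNode) and a condition $c_2$ (a candidate test), let $W_+(c)$ and $W_-(c)$ be the total weight of the positive and negative training instances satisfying $c$, and $W(c)$ the total weight of all instances satisfying $c$. Then

$$
Z_t(c_1, c_2) = 2\left(\sqrt{W_+(c_1 \wedge c_2)\,W_-(c_1 \wedge c_2)} + \sqrt{W_+(c_1 \wedge \neg c_2)\,W_-(c_1 \wedge \neg c_2)}\right) + W(\neg c_1)
$$

and the pair $(c_1, c_2)$ with the smallest $Z_t$ is added to the tree. In the Weka code, $c_1$ ends up in m_search_bestInsertionNode, $c_2$ in m_search_bestSplitter, and the minimum in m_search_smallestZ.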
private void evaluateNominalSplitSingle(int attIndex, PredictionNode currentNode,
                                        Instances posInstances, Instances negInstances)
{
    double[] indexAndZ = findLowestZNominalSplit(posInstances, negInstances, attIndex);
    if (indexAndZ[1] < m_search_smallestZ) {
        m_search_smallestZ = indexAndZ[1];
        m_search_bestInsertionNode = currentNode;
        m_search_bestSplitter = new TwoWayNominalSplit(attIndex, (int) indexAndZ[0]);
        m_search_bestPathPosInstances = posInstances;
        m_search_bestPathNegInstances = negInstances;
    }
}
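For a given attribute, this method computes a Z value for every value of that attribute: findLowestZNominalSplit(posInstances, negInstances, attIndex) simply applies the paper's formula to each value in turn and picks the smallest. It is a very simple method, so I won't reproduce its source here.
In the end, indexAndZ[0] holds the index of the attribute value and indexAndZ[1] holds the smallest Z value. The if statement then checks whether the smallest Z for this attribute is below the current minimum; if so, it updates the minimum and saves the search state (insertion node, splitter, and the positive and negative instances at that node) in the m_search_* members.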
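Since findLowestZNominalSplit() is not shown, here is a hypothetical sketch of the same idea, not Weka's source: for each value index v of one nominal attribute, take "attribute == v" as condition c2, compute Z with the unsmoothed formula from the paper, and return {bestValueIndex, smallestZ} in the same layout as indexAndZ. The inputs are assumed to be per-value weight totals of the positive and negative instances reaching the node; Weka's own version may differ in details such as smoothing.

```java
/** Hypothetical findLowestZNominalSplit-style search over the values of one attribute. */
class NominalZSketch {
    static double[] findLowestZ(double[] posWeightPerValue, double[] negWeightPerValue,
                                double trainTotalWeight) {
        double posTotal = 0.0, negTotal = 0.0;
        for (double w : posWeightPerValue) posTotal += w;
        for (double w : negWeightPerValue) negTotal += w;
        // W(not c1): weight of the training instances that do not reach this node
        double outside = trainTotalWeight - posTotal - negTotal;

        double[] indexAndZ = {-1, Double.POSITIVE_INFINITY};
        for (int v = 0; v < posWeightPerValue.length; v++) {
            double posIn = posWeightPerValue[v], negIn = negWeightPerValue[v];
            // clamp at 0 so floating-point round-off cannot make sqrt() return NaN
            double posOut = Math.max(0.0, posTotal - posIn);
            double negOut = Math.max(0.0, negTotal - negIn);
            double z = 2.0 * (Math.sqrt(posIn * negIn) + Math.sqrt(posOut * negOut)) + outside;
            if (z < indexAndZ[1]) { // keep the smallest Z, as evaluateNominalSplitSingle expects
                indexAndZ[0] = v;
                indexAndZ[1] = z;
            }
        }
        return indexAndZ;
    }
}
```

A value that separates the classes perfectly drives both square-root terms to 0, leaving only the weight outside the node; that is why pure splits win.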
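The switch statement dispatches to four methods, which differ in how they pick the next-level PredictionNodes to descend into. Below I post the code of the four methods one by one.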
private void goDownAllPathsSingle(PredictionNode currentNode,
                                  Instances posInstances, Instances negInstances)
    throws Exception {
    for (Enumeration e = currentNode.children(); e.hasMoreElements(); ) {
        Splitter split = (Splitter) e.nextElement();
        for (int i = 0; i < split.getNumOfBranches(); i++)
            searchForBestTestSingle(split.getChildForBranch(i),
                                    split.instancesDownBranch(i, posInstances),
                                    split.instancesDownBranch(i, negInstances));
    }
}
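This is the simplest one: in tree terms, it descends into every branch, so Z values are computed at every node. (That's all for today; the other three methods are also easy to follow.)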
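To make the all-paths recursion and the shared search state concrete, here is a hypothetical, heavily simplified model (none of these classes are Weka's). Every prediction node is visited with the instances that reach it, a caller-supplied score stands in for the smallest Z at that node, and the best node is kept in shared fields the way m_search_smallestZ and m_search_bestInsertionNode are. Positive and negative instances are merged into one list, and the pure/empty and Z-pure cutoffs are left out.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.ToDoubleFunction;

/** Hypothetical model of the all-paths search with shared "best so far" state. */
class AllPathsSketch {
    static class PredNode {
        final String name;
        final List<Split> children = new ArrayList<>(); // a node may carry several splits
        PredNode(String name) { this.name = name; }
    }

    /** Two-way split: instances with x[attIndex] == testValue go to branch 0. */
    static class Split {
        final int attIndex, testValue;
        final PredNode[] branches = new PredNode[2];
        Split(int attIndex, int testValue, PredNode yes, PredNode no) {
            this.attIndex = attIndex;
            this.testValue = testValue;
            branches[0] = yes;
            branches[1] = no;
        }
        List<int[]> instancesDownBranch(int branch, List<int[]> insts) {
            List<int[]> out = new ArrayList<>();
            for (int[] x : insts)
                if ((x[attIndex] == testValue ? 0 : 1) == branch) out.add(x);
            return out;
        }
    }

    // shared search state, playing the role of the m_search_* members
    double smallestScore = Double.POSITIVE_INFINITY;
    PredNode bestInsertionNode = null;
    final List<String> visited = new ArrayList<>();

    void search(PredNode node, List<int[]> insts, ToDoubleFunction<List<int[]>> score) {
        visited.add(node.name);
        double s = score.applyAsDouble(insts);
        if (s < smallestScore) { // update the global best, as evaluateNominalSplitSingle does
            smallestScore = s;
            bestInsertionNode = node;
        }
        // "all paths": recurse into every branch of every child split
        for (Split split : node.children)
            for (int i = 0; i < split.branches.length; i++)
                search(split.branches[i], split.instancesDownBranch(i, insts), score);
    }
}
```

With a root split on attribute 0 and a second split under its left child, search() visits the nodes depth first (root, the left child, its two children, then the right child), mirroring the order in which goDownAllPathsSingle() recurses.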