Learning Weka: ADTree

An ADTree has two main kinds of nodes: the PredictionNode and the SplitNode. Weka's implementation defines a data structure for each.

	public class PredictionNode
	{
		double value;
		FastVector children;
	}



value holds a or b, the prediction value (see the paper for what they mean). children holds the SplitNodes attached below this node.
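To make the two node types concrete, here is a minimal sketch of how a tree built from them scores an instance: start at the root, add every prediction value along every path the instance can reach (all splitters under a prediction node are followed, not just one), and use the sign of the sum as the class. The class names PredNode and SplitNode and the double[] instance encoding are hypothetical simplifications, not Weka's classes.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.Predicate;

/** Hypothetical, simplified ADTree node types (not Weka's API). */
final class AdTreeSketch {

    /** Prediction node: a prediction value (a or b) plus any number of splitter children. */
    static final class PredNode {
        final double value;
        final List<SplitNode> children = new ArrayList<>();
        PredNode(double value) { this.value = value; }
    }

    /** Split node: a boolean condition with exactly two prediction-node children. */
    static final class SplitNode {
        final Predicate<double[]> condition;
        final PredNode yes;
        final PredNode no;
        SplitNode(Predicate<double[]> condition, PredNode yes, PredNode no) {
            this.condition = condition;
            this.yes = yes;
            this.no = no;
        }
    }

    /** Sums the prediction values over every path the instance reaches below this node. */
    static double score(PredNode node, double[] x) {
        double sum = node.value;
        for (SplitNode s : node.children) {
            // Every splitter under a prediction node is evaluated, unlike an ordinary decision tree.
            sum += score(s.condition.test(x) ? s.yes : s.no, x);
        }
        return sum;
    }

    public static void main(String[] args) {
        PredNode root = new PredNode(0.5);
        root.children.add(new SplitNode(x -> x[0] < 3.0, new PredNode(-0.7), new PredNode(0.2)));
        root.children.add(new SplitNode(x -> x[1] == 1.0, new PredNode(0.4), new PredNode(-0.1)));
        // 0.5 + (-0.7) + 0.4 is about 0.2 > 0, so this prints "positive"
        System.out.println(score(root, new double[] {2.0, 1.0}) > 0 ? "positive" : "negative");
    }
}
```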

	public abstract class Splitter
	{
		public int orderAdded;
		/**
		 * There are also several abstract methods, omitted here for now.
		 */
	}


The author implemented two subclasses of Splitter: TwoWayNominalSplit and TwoWayNumericSplit. I'll describe their structure when we get to the code that uses them.
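As a rough idea of what "two-way" means here, this sketch assumes a nominal split sends an instance down branch 0 when the attribute equals one chosen value and down branch 1 otherwise, and a numeric split compares against a threshold. It also shows the filtering that instancesDownBranch performs. All names are hypothetical, nominal values are assumed to be stored as value indices in a double[], and missing values (which Weka's splitters handle separately) are ignored.

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical stand-ins for the two splitter kinds (not Weka's classes). */
final class TwoWaySplitSketch {

    /** Nominal: branch 0 if the attribute equals the chosen value index, otherwise branch 1. */
    static int nominalBranch(double[] x, int attIndex, int valueIndex) {
        return (int) x[attIndex] == valueIndex ? 0 : 1;
    }

    /** Numeric: branch 0 if the attribute is below the split point, otherwise branch 1. */
    static int numericBranch(double[] x, int attIndex, double splitPoint) {
        return x[attIndex] < splitPoint ? 0 : 1;
    }

    /** Keeps only the instances that go down the given branch of a nominal split. */
    static List<double[]> nominalDownBranch(List<double[]> data, int branch, int attIndex, int valueIndex) {
        List<double[]> out = new ArrayList<>();
        for (double[] x : data) {
            if (nominalBranch(x, attIndex, valueIndex) == branch) out.add(x);
        }
        return out;
    }

    public static void main(String[] args) {
        List<double[]> data = List.of(new double[] {0, 5.0}, new double[] {2, 1.0}, new double[] {0, 7.5});
        System.out.println(nominalDownBranch(data, 0, 0, 0).size()); // 2
        System.out.println(nominalDownBranch(data, 1, 0, 0).size()); // 1
        System.out.println(numericBranch(data.get(0), 1, 6.0));      // 0
    }
}
```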

The main class is, naturally, ADTree:

public class ADTree
{
	protected Instances m_trainInstances;

  /** The root of the tree */
  protected PredictionNode m_root = null;

  /** The number of the last splitter added to the tree */
  protected int m_lastAddedSplitNum = 0;

  /** An array containing the indices to the numeric attributes in the data */
  protected int[] m_numericAttIndices;

  /** An array containing the indices to the nominal attributes in the data */
  protected int[] m_nominalAttIndices;

  /** The total weight of the instances - used to speed Z calculations */
  protected double m_trainTotalWeight;

  /** The training instances with positive class - referencing the training dataset */
  protected ReferenceInstances m_posTrainInstances;

  /** The training instances with negative class - referencing the training dataset */
  protected ReferenceInstances m_negTrainInstances;

  /** The best node to insert under, as found so far by the latest search */
  protected PredictionNode m_search_bestInsertionNode;

  /** The best splitter to insert, as found so far by the latest search */
  protected Splitter m_search_bestSplitter;

  /** The smallest Z value found so far by the latest search */
  protected double m_search_smallestZ;

  /** The positive instances that apply to the best path found so far */
  protected Instances m_search_bestPathPosInstances;

  /** The negative instances that apply to the best path found so far */
  protected Instances m_search_bestPathNegInstances;
  
}
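These are the most important member variables, and the author documented each of them. Get familiar with them first; I'll explain them as we go through the methods.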

Here is the main entry point:

 public void buildClassifier(Instances instances) throws Exception {

    // Initializes m_trainInstances, m_posTrainInstances, m_negTrainInstances,
    // m_root, m_numericAttIndices, m_nominalAttIndices and m_trainTotalWeight
    initClassifier(instances);

    // A typical AdaBoost loop: one boosting round per iteration
    for (int T = 0; T < m_boostingIterations; T++) boost();
  }
  
I won't paste initClassifier(instances); it does exactly what the comment says, and you can look it up yourself. The method to analyze here is the core one, boost():
  public void boost() throws Exception {

    if (m_trainInstances == null || m_trainInstances.numInstances() == 0)
      throw new Exception("Trying to boost with no training data");

    /**
     * Entry point of the recursive search. This is where the SplitNode is actually chosen,
     * i.e. where the fields m_search_bestSplitter and m_search_bestInsertionNode are assigned.
     */
    searchForBestTestSingle();

    if (m_search_bestSplitter == null) return; // handle empty instances

    /**
     * Create the two child PredictionNodes of m_search_bestSplitter, reweighting
     * the instances that go down each branch.
     */
    for (int i=0; i<2; i++) {
      Instances posInstances =
	m_search_bestSplitter.instancesDownBranch(i, m_search_bestPathPosInstances);
      Instances negInstances =
	m_search_bestSplitter.instancesDownBranch(i, m_search_bestPathNegInstances);
      double predictionValue = calcPredictionValue(posInstances, negInstances);
      PredictionNode newPredictor = new PredictionNode(predictionValue);
      updateWeights(posInstances, negInstances, predictionValue);
      m_search_bestSplitter.setChildForBranch(i, newPredictor);
    }

    /**
     * Insert m_search_bestSplitter into the ADTree. This is where m_search_bestInsertionNode matters:
     * it records the node under which m_search_bestSplitter must be attached, so the insertion point
     * is not lost during the recursive search.
     */
    m_search_bestInsertionNode.addChild((Splitter) m_search_bestSplitter, this);

    // free memory
    m_search_bestPathPosInstances = null;
    m_search_bestPathNegInstances = null;
    m_search_bestSplitter = null;
  }
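In short, boost() finds the split with the smallest Z (finding it takes a recursive search over the tree) and uses it to create a SplitNode and its two PredictionNodes. Next comes searchForBestTestSingle(). Note that boost() calls the no-argument overload; what follows is the recursive overload that takes the current node and the positive and negative instances reaching it.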
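For each branch, boost() calls calcPredictionValue and then updateWeights. Here is a minimal sketch of those two steps, assuming the paper's smoothed prediction value, 1/2 ln((W+ + 1)/(W- + 1)), and the usual AdaBoost reweighting w ← w · exp(−y · r) with y = +1 for positive and −1 for negative instances. This is my reading of the paper, not code copied from Weka, and plain weight arrays stand in for Instances.

```java
/** Hypothetical sketch of the per-branch work in boost() (not Weka's implementation). */
final class BoostStepSketch {

    /** Smoothed prediction value: 1/2 * ln((W+ + 1) / (W- + 1)). */
    static double predictionValue(double[] posWeights, double[] negWeights) {
        return 0.5 * Math.log((sum(posWeights) + 1.0) / (sum(negWeights) + 1.0));
    }

    /** AdaBoost-style update: positives shrink by exp(-r), negatives grow by exp(r) when r > 0. */
    static void updateWeights(double[] posWeights, double[] negWeights, double r) {
        for (int i = 0; i < posWeights.length; i++) posWeights[i] *= Math.exp(-r);
        for (int i = 0; i < negWeights.length; i++) negWeights[i] *= Math.exp(r);
    }

    static double sum(double[] w) {
        double s = 0.0;
        for (double v : w) s += v;
        return s;
    }

    public static void main(String[] args) {
        double[] pos = {1.0, 1.0, 1.0};
        double[] neg = {1.0};
        double r = predictionValue(pos, neg); // 0.5 * ln(4 / 2), about 0.3466
        updateWeights(pos, neg, r);
        System.out.printf("r=%.4f pos=%.4f neg=%.4f%n", r, pos[0], neg[0]);
    }
}
```

Instances that do not reach the chosen branch are left untouched, which matches boost() only passing the branch's own instances to updateWeights.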

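Note: so far I have only looked at how nominal attributes are handled, not continuous numeric ones. I'll add that once I've read the numeric code.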

private void searchForBestTestSingle(PredictionNode currentNode,
				       Instances posInstances, Instances negInstances)
    throws Exception {

    // don't investigate pure or empty nodes any further

    if (posInstances.numInstances() == 0 || negInstances.numInstances() == 0) return;

    // do z-pure cutoff
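    /**
     * I couldn't find this formula in the paper. The code uses it as a cutoff: if the Z-pure value
     * for this node is already no smaller than the best Z found so far, neither this node nor
     * anything below it is explored.
     */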
    if (calcZpure(posInstances, negInstances) >= m_search_smallestZ) return;

    /**
     * These two lines can be ignored; they only record search statistics.
     */
    m_nodesExpanded++;
    m_examplesCounted += posInstances.numInstances() + negInstances.numInstances();

    // evaluate static splitters (nominal)
    /**
     * Compute the Z value for each nominal attribute.
     */
    for (int i=0; i<m_nominalAttIndices.length; i++)
      evaluateNominalSplitSingle(m_nominalAttIndices[i], currentNode,
				 posInstances, negInstances);

    // evaluate dynamic splitters (numeric)
    if (m_numericAttIndices.length > 0) {

      // merge the two sets of instances into one
      Instances allInstances = new Instances(posInstances);
      for (Enumeration e = negInstances.enumerateInstances(); e.hasMoreElements(); )
	allInstances.add((Instance) e.nextElement());
    
      // use method of finding the optimal Z split-point
      for (int i=0; i<m_numericAttIndices.length; i++)
	evaluateNumericSplitSingle(m_numericAttIndices[i], currentNode,
				   posInstances, negInstances, allInstances);
    }
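    /**
     * Return point of the recursion. If this PredictionNode has no children, stop here. Otherwise,
     * go on to compute the Z values for the two branches (two PredictionNodes) of each child
     * (a SplitNode), recursing again wherever those have children of their own.
     * This shows that Z values are compared across every branch point in the whole tree. A higher
     * level may yield a smaller Z than a lower one, in which case another split is added at that
     * higher level. This is why a PredictionNode can end up with several SplitNodes (a multiway tree),
     * and why the same attribute can be tested more than once.
     */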
    if (currentNode.getChildren().size() == 0) return;

    // keep searching
    switch (m_searchPath) {
    case SEARCHPATH_ALL:
      goDownAllPathsSingle(currentNode, posInstances, negInstances);
      break;
    case SEARCHPATH_HEAVIEST: 
      goDownHeaviestPathSingle(currentNode, posInstances, negInstances);
      break;
    case SEARCHPATH_ZPURE: 
      goDownZpurePathSingle(currentNode, posInstances, negInstances);
      break;
    case SEARCHPATH_RANDOM: 
      goDownRandomPathSingle(currentNode, posInstances, negInstances);
      break;
    }
  }
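According to the paper, the algorithm computes Z for every combination of precondition (the path ending at a PredictionNode) and candidate condition (a SplitNode). In tree terms, every PredictionNode is tried as an insertion point for every candidate split. The smallest Z found so far must therefore be kept in member fields that persist across the whole recursion rather than in local variables. That is the role of these fields:

  protected PredictionNode m_search_bestInsertionNode;
  protected Splitter m_search_bestSplitter;
  protected double m_search_smallestZ;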
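For reference, here is the Z criterion and the prediction values from the Freund and Mason ADTree paper, as I read it. c_1 is the precondition of a PredictionNode (the conjunction of tests on the path from the root), c_2 is a candidate split condition, W_+(c) and W_-(c) are the total current weights of positive and negative training instances satisfying c, and W(¬c_1) is the weight of the instances that do not reach the node (which is why the total training weight is cached in m_trainTotalWeight). Smaller Z is better; a and b are the values of the two new PredictionNodes. The +1 smoothing is how I recall the paper, so check it against the original before relying on it.

```latex
Z(c_1, c_2) = 2\left(\sqrt{W_+(c_1 \wedge c_2)\,W_-(c_1 \wedge c_2)}
            + \sqrt{W_+(c_1 \wedge \neg c_2)\,W_-(c_1 \wedge \neg c_2)}\right) + W(\neg c_1)

a = \frac{1}{2}\ln\frac{W_+(c_1 \wedge c_2) + 1}{W_-(c_1 \wedge c_2) + 1}, \qquad
b = \frac{1}{2}\ln\frac{W_+(c_1 \wedge \neg c_2) + 1}{W_-(c_1 \wedge \neg c_2) + 1}
```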

 private void evaluateNominalSplitSingle(int attIndex, PredictionNode currentNode,
					  Instances posInstances, Instances negInstances)
  {
    
    double[] indexAndZ = findLowestZNominalSplit(posInstances, negInstances, attIndex);

    if (indexAndZ[1] < m_search_smallestZ) {
      m_search_smallestZ = indexAndZ[1];
      m_search_bestInsertionNode = currentNode;
      m_search_bestSplitter = new TwoWayNominalSplit(attIndex, (int) indexAndZ[0]);
      m_search_bestPathPosInstances = posInstances;
      m_search_bestPathNegInstances = negInstances;
    }
  }
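For a single attribute, this method computes Z for every value of that attribute: findLowestZNominalSplit(posInstances, negInstances, attIndex) applies the paper's formula to each candidate split (attribute equals the value vs. does not) in turn and keeps the smallest. It's a simple method, so I won't paste its source.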

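When it returns, indexAndZ[0] holds the index of the best attribute value and indexAndZ[1] holds the smallest Z. The if statement then checks whether this attribute's smallest Z beats the current minimum and, if so, updates it. The best-so-far state is saved in the fields m_search_smallestZ, m_search_bestInsertionNode and m_search_bestSplitter, together with the positive and negative instances of the best path.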
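Since findLowestZNominalSplit isn't shown, here is a hypothetical sketch of the same idea. It totals the positive and negative weight per attribute value in one pass, scores each "equals v" vs. "not v" split with the paper's unsmoothed Z formula, and returns {bestValueIndex, smallestZ} like indexAndZ. Plain arrays replace Instances, restWeight stands for W(¬c_1), and Weka's actual method may differ in details such as smoothing.

```java
/** Hypothetical findLowestZNominalSplit-style search (not Weka's code). */
final class NominalZSketch {

    /** Z = 2 * (sqrt(W+1 * W-1) + sqrt(W+2 * W-2)) + W(not c1). */
    static double z(double pos1, double neg1, double pos2, double neg2, double restWeight) {
        return 2.0 * (Math.sqrt(pos1 * neg1) + Math.sqrt(pos2 * neg2)) + restWeight;
    }

    /**
     * values[i]   value index of the attribute for instance i (instances reaching this node)
     * positive[i] true if instance i is in the positive class
     * weight[i]   current boosting weight of instance i
     */
    static double[] findLowestZ(int[] values, boolean[] positive, double[] weight,
                                int numValues, double restWeight) {
        double[] posPerValue = new double[numValues];
        double[] negPerValue = new double[numValues];
        double posTotal = 0.0, negTotal = 0.0;
        for (int i = 0; i < values.length; i++) {
            if (positive[i]) { posPerValue[values[i]] += weight[i]; posTotal += weight[i]; }
            else             { negPerValue[values[i]] += weight[i]; negTotal += weight[i]; }
        }
        double[] indexAndZ = {-1, Double.POSITIVE_INFINITY};
        for (int v = 0; v < numValues; v++) {
            // Branch 1: attribute == v; branch 2: attribute != v.
            double zv = z(posPerValue[v], negPerValue[v],
                          posTotal - posPerValue[v], negTotal - negPerValue[v], restWeight);
            if (zv < indexAndZ[1]) { indexAndZ[0] = v; indexAndZ[1] = zv; }
        }
        return indexAndZ;
    }

    public static void main(String[] args) {
        // Value 0 is purely positive and value 1 purely negative, so the split is perfect.
        double[] best = findLowestZ(new int[] {0, 0, 1, 1}, new boolean[] {true, true, false, false},
                                    new double[] {1, 1, 1, 1}, 2, 0.0);
        System.out.println((int) best[0] + " " + best[1]); // 0 0.0
    }
}
```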

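The switch statement dispatches to four methods, which differ in how they choose the next-level PredictionNodes to descend into. I'll go through the code of each in turn.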

 private void goDownAllPathsSingle(PredictionNode currentNode,
				    Instances posInstances, Instances negInstances)
    throws Exception {

    for (Enumeration e = currentNode.children(); e.hasMoreElements(); ) {
      Splitter split = (Splitter) e.nextElement();
      for (int i=0; i<split.getNumOfBranches(); i++)
	searchForBestTestSingle(split.getChildForBranch(i),
				split.instancesDownBranch(i, posInstances),
				split.instancesDownBranch(i, negInstances));
    }
  }
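This one is the simplest: in tree terms, it computes Z at every node, descending into both branches of every Splitter. (That's all for today; the remaining three methods are also easy to follow.)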
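To see why the best-so-far state lives in fields, here is a toy version of the all-paths search. Each hypothetical PredNode carries a precomputed localBestZ that stands in for evaluating all candidate splits at that node; the recursion visits every PredictionNode through both branches of every splitter, much like goDownAllPathsSingle, and updates two fields that play the roles of m_search_smallestZ and m_search_bestInsertionNode. The z-pure cutoff and the instance filtering are left out.

```java
import java.util.ArrayList;
import java.util.List;

/** Toy all-paths search with shared best-so-far state (hypothetical, not Weka's code). */
final class AllPathsSearchSketch {

    static final class PredNode {
        final String name;
        final double localBestZ;                              // stand-in for evaluating candidate splits here
        final List<PredNode[]> splitters = new ArrayList<>(); // each splitter: {branch0, branch1}
        PredNode(String name, double localBestZ) { this.name = name; this.localBestZ = localBestZ; }
    }

    // Search state shared by the whole recursion (cf. m_search_smallestZ / m_search_bestInsertionNode).
    double smallestZ;
    PredNode bestInsertionNode;

    PredNode search(PredNode root) {
        smallestZ = Double.POSITIVE_INFINITY;
        bestInsertionNode = null;
        visit(root);
        return bestInsertionNode;
    }

    private void visit(PredNode node) {
        if (node.localBestZ < smallestZ) {
            smallestZ = node.localBestZ;
            bestInsertionNode = node;
        }
        // Descend into both branches of every splitter under this node.
        for (PredNode[] split : node.splitters) {
            for (PredNode child : split) visit(child);
        }
    }

    public static void main(String[] args) {
        PredNode root = new PredNode("root", 5.0);
        PredNode left = new PredNode("left", 3.0);
        root.splitters.add(new PredNode[] {left, new PredNode("right", 4.0)});
        left.splitters.add(new PredNode[] {new PredNode("left.0", 3.5), new PredNode("left.1", 2.5)});
        System.out.println(new AllPathsSearchSketch().search(root).name); // left.1
    }
}
```

Because a local variable would be lost when each recursive call returns, the winner found deep in the tree (left.1 here) only survives because it is written to the shared fields.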

