spark ml 随机森林源码笔记三

二已经很长了,告一段路,从三开始真正构建决策森林,首先创建缓存节点id的RDD,让所有点属于跟节点

    val nodeIdCache = if (strategy.useNodeIdCache) {
      Some(NodeIdCache.init(
        data = baggedInput,
        numTrees = numTrees,
        checkpointInterval = strategy.checkpointInterval,
        initVal = 1))
    } else {
      None
    }

引出一个类NodeIdCache,init后创建了一个nodeIdCache类,有两个成员,一个是Array.fill[Int](numTrees)(initVal),也就是每条数据对应一个初始数组,长度是numTrees,元素是1,因为最终每个点只属于每棵树的一个节点,通过元素索引来表示,另一个成员是checkpointInterval,还有如下几个私有变量

private[spark] class NodeIdCache(
  var nodeIdsForInstances: RDD[Array[Int]],
  val checkpointInterval: Int) extends Logging {
  private var prevNodeIdsForInstances: RDD[Array[Int]] = null//记录上一次更新数据对应每棵树节点索引
  private val checkpointQueue = mutable.Queue[RDD[Array[Int]]]()//因为是多次更新,搞一个queue记录更新前后关系
  private var rddUpdateCount = 0//更新的次数,满足checkpointInterval次数就checkpoint
  private val canCheckpoint = nodeIdsForInstances.sparkContext.getCheckpointDir.nonEmpty//检查Checkpoint目录是否存在,是否可以Checkpoint
  private val fs = FileSystem.get(nodeIdsForInstances.sparkContext.hadoopConfiguration)//获取文件系统实例,以进行删除Checkpoint操作

类的方法在主类中都会用到,还是按照主类顺序介绍

    // FIFO queue of nodes to train: (treeIndex, node)  
    val nodeQueue = new mutable.Queue[(Int, LearningNode)]()//搞一个队列,记录树索引和节点索引

    // Allocate and queue root nodes.
    val topNodes = Array.fill[LearningNode](numTrees)(LearningNode.emptyNode(nodeIndex = 1))//构建索引为1即根节点的空节点的数组,长度为树的棵数
    Range(0, numTrees).foreach(treeIndex => nodeQueue.enqueue((treeIndex, topNodes(treeIndex))))//给nodeQueue赋值,树的索引,初始跟节点

下面就是最核心的方法了

    while (nodeQueue.nonEmpty) {//循环刚才的队列直到空,这个为空就说明已经把森林建好了
      // Collect some nodes to split, and choose features for each node (if subsampling).//注释的意思是找一些点来划分,如果属性采样就选择属性,每组点可能来自不同的树,不同层级的节点,核心还是找一些节点来划分,不是按树的分裂顺序划分,这个和传统机器学习决策树是不同的,这里还是很抽象,我们继续看selectNodesToSplit
      // Each group of nodes may come from one or multiple trees, and at multiple levels.
      val (nodesForGroup, treeToNodeToIndexInfo) =
        RandomForest.selectNodesToSplit(nodeQueue, maxMemoryUsage, metadata, rng)
      // Sanity check (should never occur):
      assert(nodesForGroup.nonEmpty,
        s"RandomForest selected empty nodesForGroup.  Error for unknown reason.")


      // Choose node splits, and enqueue new nodes as needed.
      timer.start("findBestSplits")
      RandomForest.findBestSplits(baggedInput, metadata, topNodes, nodesForGroup,
        treeToNodeToIndexInfo, splits, nodeQueue, timer, nodeIdCache)
      timer.stop("findBestSplits")
    }

selectNodesToSplit从名字上看就是如何选节点,注释说在属性采样的情况下,跟踪聚合所占内存,如果需要过多内存就停止增加节点

  private[tree] class NodeIndexInfo(
      val nodeIndexInGroup: Int,
      val featureSubset: Option[Array[Int]]) extends Serializable//NodeIndexInfo是个类,里面有节点在组中的索引,节点属性子集

  /**
   * Pull nodes off of the queue, and collect a group of nodes to be split on this iteration.
   * This tracks the memory usage for aggregates and stops adding nodes when too much memory
   * will be needed; this allows an adaptive number of nodes since different nodes may require
   * different amounts of memory (if featureSubsetStrategy is not "all").
   *
   * @param nodeQueue  Queue of nodes to split.
   * @param maxMemoryUsage  Bound on size of aggregate statistics.
   * @return  (nodesForGroup, treeToNodeToIndexInfo).//返回值和下面的数据结构还是比较抽象的,我们先简单翻译一下,大概有个印象,后面再看是如何用的,千万别死盯着注释,由于翻译和信息传递问题,注释有时候和代码表达的意思不太一致
   *          nodesForGroup holds the nodes to split: treeIndex --> nodes in tree.//nodesForGroup 是个数组,记录树索引对应的节点索引
   *
   *          treeToNodeToIndexInfo holds indices selected features for each node:
   *            treeIndex --> (global) node index --> (node index in group, feature indices).//这是三级索引,一级树索引,二级节点在树中的索引,节点在组中的索引与特征索引组成的元组,说实话,有时候看别人的数据结构比自己设计数据结构还难
   *          The (global) node index is the index in the tree; the node index in group is the//节点在树中的索引已经解释了,节点在组中的索引说明有(0,numNodesInGroup)这么一个数组,现在还不知道是干什么的,后面才能看清楚
   *           index in [0, numNodesInGroup) of the node in this group.
   *          The feature indices are None if not subsampling features.//属性索引没啥好说的
   */
  private[tree] def selectNodesToSplit(
      nodeQueue: mutable.Queue[(Int, LearningNode)],
      maxMemoryUsage: Long,
      metadata: DecisionTreeMetadata,
      rng: Random): (Map[Int, Array[LearningNode]], Map[Int, Map[Int, NodeIndexInfo]]) = {
    // Collect some nodes to split:
    //  nodesForGroup(treeIndex) = nodes to split
    val mutableNodesForGroup = new mutable.HashMap[Int, mutable.ArrayBuffer[LearningNode]]()//mutableNodesForGroup:每棵树要划分的若干节点
    val mutableTreeToNodeToIndexInfo =
      new mutable.HashMap[Int, mutable.HashMap[Int, NodeIndexInfo]]()//构造了第二个输出
    var memUsage: Long = 0L
    var numNodesInGroup = 0
    while (nodeQueue.nonEmpty && memUsage < maxMemoryUsage) {//如果队列不为空且内存够用做循环
      val (treeIndex, node) = nodeQueue.head //取出第一个树的根节点
      // Choose subset of features for node (if subsampling).
      val featureSubset: Option[Array[Int]] = if (metadata.subsamplingFeatures) {
        Some(SamplingUtils.reservoirSampleAndCount(Range(0,   //reservoirSampleAndCount这个方法别害怕,其实就是属性抽样,这里有个重要的事实就是每个节点的属性是随机的,而不是没棵树的属性是随机的
          metadata.numFeatures).iterator, metadata.numFeaturesPerNode, rng.nextLong())._1)
      } else {
        None
      }
      // Check if enough memory remains to add this node to the group.
      val nodeMemUsage = RandomForest.aggregateSizeForNode(metadata, featureSubset) * 8L//aggregateSizeForNode是个方法,跟之前一样,如果是分类结果是分类数乘以总装箱数,如果是回归就是3倍总装箱数,nodeMemUsage 是需要使用的内存
      if (memUsage + nodeMemUsage <= maxMemoryUsage) { //如果内存够用
        nodeQueue.dequeue()//队列弹出一棵树的一个元素
        mutableNodesForGroup.getOrElseUpdate(treeIndex, new mutable.ArrayBuffer[LearningNode]()) +=
          node//如果mutableNodesForGroup已有这个树的索引,就返回这个索引对应的节点数组,如果没有这个树的索引,创建这个树索引对应的节点数组,最后都有把出栈的节点放到数组中
        mutableTreeToNodeToIndexInfo
          .getOrElseUpdate(treeIndex, new mutable.HashMap[Int, NodeIndexInfo]())(node.id)
          = new NodeIndexInfo(numNodesInGroup, featureSubset)//类似,mutableTreeToNodeToIndexInfo已有这个树的索引就返回这个索引对应的map,否则创建这个树索引对应的map,最后都把NodeIndexInfo放到数组中,这里map的key是node.id即节点id,value是NodeIndexInfo,包含,这个节点在组中的索引和属性集两个成员
      }
      numNodesInGroup += 1  //组元素自增
      memUsage += nodeMemUsage//累加使用的内存

//到这里我们看到所谓的组就是满足while (nodeQueue.nonEmpty && memUsage < maxMemoryUsage) {这个条件所能放的元素,构成一组
    }

    // Convert mutable maps to immutable ones.
    val nodesForGroup: Map[Int, Array[LearningNode]] =
      mutableNodesForGroup.mapValues(_.toArray).toMap
    val treeToNodeToIndexInfo = mutableTreeToNodeToIndexInfo.mapValues(_.toMap).toMap
    (nodesForGroup, treeToNodeToIndexInfo)//输出结果
  }

selectNodesToSplit完了后是findBestSplits,是划分节点的核心方法,先看看参数

  private[tree] def findBestSplits(
      input: RDD[BaggedPoint[TreePoint]],//输入数据
      metadata: DecisionTreeMetadata,//输入数据元数据
      topNodes: Array[LearningNode],//各棵树的根节点,用于数据和节点匹配
      nodesForGroup: Map[Int, Array[LearningNode]],//上一个方法得到的数据结构,记录树索引和要划分节点的map
      treeToNodeToIndexInfo: Map[Int, Map[Int, NodeIndexInfo]],//上一个方法得到的数据结构,记录树索引,节点索引到这个节点信息的二级map
      splits: Array[Array[Split]],//记录所有划分,是个二维数据
      nodeQueue: mutable.Queue[(Int, LearningNode)],//由树索引和节点组成的队列,注释说生成的非叶节点会更新这个队列,后面才能看出是不是这样,另外非常重要的一点是nodeQueue是队列,selectNodesToSplit操作了这个队列,findBestSplits再使用时时已不是最初始的那个队列了
      timer: TimeTracker = new TimeTracker,//这个没啥可说的
      nodeIdCache: Option[NodeIdCache] = None   //是包含数据和每棵树的节点对应关系的数据结构  ): Unit = {

后面注释还说了一下这样做的原因

   /*
     * The high-level descriptions of the best split optimizations are noted here.
     *
     * *Group-wise training* //按组训练可以减少数据传递,每次迭代需要更多计算和存储,但节省了迭代次数 
     * We perform bin calculations for groups of nodes to reduce the number of     
     * passes over the data.  Each iteration requires more computation and storage,
     * but saves several iterations over the data.
     *
     * *Bin-wise computation*//使用装箱计算而不是直接计算最优划分,不是考虑每个样本归左子树还是右子树导致的每个划分不纯性的变化,而是把每个样本的每个属性值装箱,我们利用这个结构去计算装箱累计值并使用这个累计值计算划分的信息增益
     * We use a bin-wise best split computation strategy instead of a straightforward best split
     * computation strategy. Instead of analyzing each sample for contribution to the left/right
     * child node impurity of every split, we first categorize each feature of a sample into a
     * bin. We exploit this structure to calculate aggregates for bins and then use these aggregates
     * to calculate information gain for each split.
     *
     * *Aggregation over partitions*//分区聚合而不是用flatMap/reduceByKey算子,由于预先知道装箱数量,用数组存储聚合结果,通过rdd的聚合方法大幅减少通信
     * Instead of performing a flatMap/reduceByKey operation, we exploit the fact that we know
     * the number of splits in advance. Thus, we store the aggregates (at the appropriate
     * indices) in a single array for all bins and rely upon the RDD aggregate method to
     * drastically reduce the communication overhead.
     */

val numNodes = nodesForGroup.values.map(_.length).sum//组内节点数

    val nodes = new Array[LearningNode](numNodes)//  根据nodesForGroup构建组节点数组,也就是nodesForGroup里的nodeIndexInGroup就是nodes 的索引
    nodesForGroup.foreach { case (treeIndex, nodesForTree) =>
      nodesForTree.foreach { node =>
        nodes(treeToNodeToIndexInfo(treeIndex)(node.id).nodeIndexInGroup) = node
      }
    }

    // In each partition, iterate all instances and compute aggregate stats for each node,//还是先简单解释下注释,简单了解下后面的工作,每个分区,迭代所有实例并为每个节点计算聚合信息,为每个节点构建一个节点信息和节点聚合信息的元组,通过reduceByKey操作,节点统计shuffled到一个分区并聚合,这样就找到了节点最优划分,最终最优划分被收集到driver上,就形成了决策树,听起来还是比较模糊的,还得一行一行看才能明白
    // yield an (nodeIndex, nodeAggregateStats) pair for each node.
    // After a `reduceByKey` operation,
    // stats of a node will be shuffled to a particular partition and be combined together,
    // then best splits for nodes are found there.
    // Finally, only best Splits for nodes are collected to driver to construct decision tree.

    val nodeToFeatures = getNodeToFeatures(treeToNodeToIndexInfo)//这个方法获取组索引和属性集索引的映射关系并广播
    val nodeToFeaturesBc = input.sparkContext.broadcast(nodeToFeatures)

    val partitionAggregates : RDD[(Int, DTStatsAggregator)] = if (nodeIdCache.nonEmpty) {//计算分区聚合,返回的就是注释说的节点索引和节点聚合信息构成的元组,而DTStatsAggregator是包含元数据和属性集的类
      input.zip(nodeIdCache.get.nodeIdsForInstances).mapPartitions { points =>//如果有nodeIdCache,通过zip把数据和数据对应各树节点id的数组结合
        // Construct a nodeStatsAggregators array to hold node aggregate stats,
        // each node will have a nodeStatsAggregator  //每个节点都对应一个nodeStatsAggregator
        val nodeStatsAggregators = Array.tabulate(numNodes) { nodeIndex =>  //tabluate是个常用的算子,功能是用第二个方法参数作用从0开始的n个元素,这里就是获得组节点对应的节点聚合数组
          val featuresForNode = nodeToFeaturesBc.value.flatMap { nodeToFeatures =>
            Some(nodeToFeatures(nodeIndex))
          }
          new DTStatsAggregator(metadata, featuresForNode)
        }


        // iterator all instances in current partition and update aggregate stats //不断更新聚合信息
        points.foreach(binSeqOpWithNodeIdCache(nodeStatsAggregators, _))   //binSeqOpWithNodeIdCache这个方法会引出很多方法,但做的事没没有那么复杂,points是一个数组,做循环的结果是把每个元素的信息都聚合到nodeStatsAggregators,注意这里是每个partition聚合成一个nodeStatsAggregators,又印证了注释,第二个参数_是points元素的占位符,聚合是以每个树每个节点每个属性每个装箱每个数据的粒度开始,如果不是unorderedFeatures属性,通过orderedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, instanceWeight, featuresForNode)处理,否则通过mixedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, splits,metadata.unorderedFeatures, instanceWeight, featuresForNode)处理

先说一下orderedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, instanceWeight, featuresForNode),从名字看属性都是有序属性装箱操作,里面最终会调用agg.update(featureIndexIdx, binIndex, label, instanceWeight),agg就是以节点组为长度的数组中的一个元素,更新以后返回的还是agg自己,featureIndexIdx是特征索引,binIndex是这个特征下的装箱索引,label是样本label值,instanceWeight是抽样后这棵树有多少个这样的样本,

里面第一个方法是val i = featureOffsets(featureIndex) + binIndex * statsSize,而featureOffsets又是通过numBins.scanLeft(0)((total, nBins) => total + statsSize * nBins)计算,scanLeft是个很重要的算子,第一个参数是初始值,第二个参数是函数,函数第一个参数是初始值,第二个参数是数组的元素,计算结果就是新数组的元素,对于回归来说statsSize是用于统计bin的向量长度,是3,这也就解释了计算节点内存为什么是3*totalBins,3就是向量长度,回到scanLeft,scanLeft结果就是每个属性totalBins的偏移,0感觉没什么毛子用,这样i就是属性totalBins的偏移+bin索引的偏移。

第二个方法是impurityAggregator.update(allStats, i, label, instanceWeight),allStats是Array[Double](allStatsSize),而allStatsSize是featureOffsets.last即featureOffsets最后一个元素开始的索引位置,下面是更新函数,这里再简单总结下,allStats是一个一维数组,但是是由两级索引构成的,第一级是属性索引,第二级是这个属性的bin索引,相当于放到一个长数组里,而allStats的声明长度是featureOffsets的最后一个属性的第一个bin位置索引的长度,并不是实际需要的长度,由于数组作为参数,只需要第一个bin位置即可,后面的累计还是会增加这个索引的

def update(allStats: Array[Double], offset: Int, label: Double, instanceWeight: Double): Unit = {//这下你就看清楚每个bin索引对应的3个元素里是什么了,是不是瞬间从懵B到豁然开朗
    allStats(offset) += instanceWeight //第一个元素权重
    allStats(offset + 1) += instanceWeight * label//第二个权重乘以label值
    allStats(offset + 2) += instanceWeight * label * label//第三个权重乘以label平方
  }

然后是mixedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, splits,metadata.unorderedFeatures, instanceWeight, featuresForNode),从名字看属性是混合型的,即既有ordered,又有unordered,ordered和orderedBinSeqOp的处理类似,unordered会分成左装箱或右装箱,先介绍下参数agg存聚合值的数组,aggNodeIndex是组索引值,baggedPoint.datum是样本,splits是各属性的划分索引,metadata.unorderedFeatures是unordered属性的索引集合,instanceWeight是样本在各树中的采样数量可以理解为权重,featuresForNode是组节点的属性索引数组。混合属性如果是ordered那么处理方式和orderedBinSeqOp一致,如果是unordered,先计算 val (leftNodeFeatureOffset, rightNodeFeatureOffset) =agg.getLeftRightFeatureOffsets(featureIndexIdx),getLeftRightFeatureOffsets从名字来看是左右特征偏移,参数是特征索引,最终

leftNodeFeatureOffset=featureOffsets(featureIndex)//属性偏移的开始位置作为左节点属性偏移

rightNodeFeatureOffset=featureOffsets(featureIndex)+(numBins(featureIndex) >> 1) * statsSize//属性偏移的开始位置+装箱的右位移也就是一半*statsSize作为右节点属性偏移,跟之前的原理一样

里面最核心的一段是

        while (splitIndex < numSplits) {//对某个属性,一直循环知道划分数
          if (featureSplits(splitIndex).shouldGoLeft(featureValue, featureSplits)) {//featureSplits(splitIndex)是splitIndex对应的split类,featureValue一定别忘了是10010这种,因为是unordered属性,featureSplits是split类的数组,shouldGoLeft有点意思,贴出相关代码

  override private[tree] def shouldGoLeft(binnedFeature: Int, splits: Array[Split]): Boolean = {
    if (isLeft) {// private val isLeft: Boolean = _leftCategories.length <= numCategories / 2  对某个属性, 如果数组变量数不超过总变量数的一半,认为是左子树,即这里把变量少的认为是左子树,之前左子树的索引也是不超过右子树
      categories.contains(binnedFeature.toDouble) //private val categories: Set[Double] = {if (isLeft) {_leftCategories.toSet} else {setComplement(_leftCategories.toSet)}}如果数组长度少的,categories是数组变量集合,如果数组长度长,categories取与另外一个并集,如果这个元素落在短数组里,shouldGoLeft,否则就去长的一边也就是右边了
    } else {
      !categories.contains(binnedFeature.toDouble)//逻辑和上面差不多,这左左右右的还真挺烦
    }
  }
            agg.featureUpdate(leftNodeFeatureOffset, splitIndex, treePoint.label, instanceWeight)//featureUpdate跟之前的agg.update逻辑一样,只不过换了个名,从后面看出这里就是区分了个leftNodeFeatureOffset,还是rightNodeFeatureOffset,还有一个差异就是featureValue是不是具体的值,而是10010这种,这两点决定了unordered的逻辑和ordered不同,另外对于类别属性,如果变量值过多,是不会被弄成unordered属性的,因为变量值越多,组合越多,数据量太大了,尽管如此,这样累加计算量确实很大,正如之前注释说的
          } else {
            agg.featureUpdate(rightNodeFeatureOffset, splitIndex, treePoint.label, instanceWeight)
          }
          splitIndex += 1
        }

终于回到主类,agg始终以数组传递更新,最后返回的还是agg

        // transform nodeStatsAggregators array to (nodeIndex, nodeAggregateStats) pairs,
        // which can be combined with other partition using `reduceByKey`
        nodeStatsAggregators.view.zipWithIndex.map(_.swap).iterator//把返回的agg复制副本,搞出组索引并交换组索引和值,形成RDD[(Int, DTStatsAggregator)]
      }
    } else {  //如果不使用nodeIdCache就简单了,下面的逻辑就只是上面的一部分,做完这一堆就完成了分区聚合,但是不同分区有相同的索引
      input.mapPartitions { points =>    
        // Construct a nodeStatsAggregators array to hold node aggregate stats,
        // each node will have a nodeStatsAggregator
        val nodeStatsAggregators = Array.tabulate(numNodes) { nodeIndex =>
          val featuresForNode = nodeToFeaturesBc.value.flatMap { nodeToFeatures =>
            Some(nodeToFeatures(nodeIndex))
          }
          new DTStatsAggregator(metadata, featuresForNode)
        }


        // iterator all instances in current partition and update aggregate stats
        points.foreach(binSeqOp(nodeStatsAggregators, _))


        // transform nodeStatsAggregators array to (nodeIndex, nodeAggregateStats) pairs,
        // which can be combined with other partition using `reduceByKey`
        nodeStatsAggregators.view.zipWithIndex.map(_.swap).iterator
      }
    }

    val nodeToBestSplits = partitionAggregates.reduceByKey((a, b) => a.merge(b)).map {//刚才说了不同分区,索引相同,现在聚合索引相同的DTStatsAggregator,聚合完后就剩下组元素个DTStatsAggregator
      case (nodeIndex, aggStats) =>       //聚合完剩下组nodeID,DTStatsAggregator
        val featuresForNode = nodeToFeaturesBc.value.flatMap { nodeToFeatures => //找出组元素对应的属性索引数组,和之前一样
          Some(nodeToFeatures(nodeIndex))
        }


        // find best split for each node//find best split!终于看到这几个字了,泪奔
        val (split: Split, stats: ImpurityStats) =
          nodeToBestSplits (aggStats, splits, featuresForNode, nodes(nodeIndex))//binsToBestSplit这就是泪奔的方法,参数aggStats节点聚合信息,splits各种划分,featuresForNode这个节点的属性索引数组,nodes是之前声明一直没用过的,nodes(nodeIndex)是组索引对应的node
        (nodeIndex, (split, stats))  //返回的值加组索引就是nodeToBestSplits 
    }.collectAsMap()

我们在四中介绍nodeToBestSplits



  • 1
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值