Spark2.3 源码解析 之 随机森林 random forest

Spark2.3 源码解析 之 随机森林 random forest简介1. bagging如下图所示,bagging的思想“三个臭皮匠胜过诸葛亮” ,即训练多个弱分类器,之后大家共同产生最终结果:投票表决或者平均值。其中每个若分类器之间没有前后关联(与boosting区别),训练若分类器的前提就是随机采样。这里的抽样是有放回随机抽样,一般每个采样集和训练集的数量一致,即每个采样...
摘要由CSDN通过智能技术生成

Spark2.3 源码解析 之 随机森林 random forest

简介

1. bagging

如下图所示,bagging的思想“三个臭皮匠胜过诸葛亮” ,即训练多个弱分类器,之后大家共同产生最终结果:投票表决或者平均值。

其中每个若分类器之间没有前后关联(与boosting区别),训练若分类器的前提就是随机采样。这里的抽样是有放回随机抽样,一般每个采样集和训练集的数量一致,即每个采样集也要采样m个样本。对于K个若分类器,就要进行K次随机采样,由此得到K个不同的样本子集。

对于一个样本,它在某一次含m个样本的训练集的随机采样中,每次被采集到的概率是1/m。不被采集到的概率为1−1/m。如果m次采样都没有被采集中的概率是(1-1/m)的m次方。当m无穷大这个公式约等于1/e≃0.368。 因此,在bagging的每轮随机采样中,训练集中大约有36.8%的数据没有被采样集采集中。这部分数据被称之为袋外数据(Out Of Bag, 简称OOB)。这些数据没有参与训练集模型的拟合,因此可以用来检测模型的泛化能力

  1. 抽样角度,与GBDT简单对比
  2. bagging是有放回抽样,而GBDT是无放回抽样
  3. bagging每个都是弱分类器都是训练样本的采样,因此泛化能力强。depth通常大于gbdt,不易过拟合

2. random forest 随机森林

理论上,bagging的若分类器可以是任意模型,但是较为通用的若分类器主要有两种:一个是决策树,另一个就是神经网络。当若分类器为决策树是,bagging就变成了随机森林。以下就spark2.3的random forest的代码进行详细介绍。
需要强调的是:spark的实现中,不仅进行了样本采样,同时也进行了特征抽样(列采样)
Tips:关于决策树的知识请参考上一篇文章决策树 decision tree

一、整体思路

1、Stack存储Nodes

(1)Forest中的所有tree一起处理;
(2)利用Stack存储Nodes:如果当前正在split某个tree的节点,那么新的Nodes也属于该tree,因此下一轮迭代训练的是同一棵tree。虽然同时训练多个tree,但并不是并行训练多个tree,而是focus on completing trees。
(3)每次在Stack中取出若干个Nodes,组成nodesForGroup,一并处理:RandomForest.findBestSplits,一起寻找最优

 

代码位于RandomForest.scala的run方法中
/*
      Stack of nodes to train: (treeIndex, node)
      The reason this is a stack is that we train many trees at once, but we want to focus on
      completing trees, rather than training all simultaneously.  If we are splitting nodes from
      1 tree, then the new nodes to split will be put at the top of this stack, so we will continue
      training the same tree in the next iteration.  This focus allows us to send fewer trees to
      workers on each iteration; see topNodesForGroup below.
     */
    val nodeStack = new mutable.ArrayStack[(Int, LearningNode)]

    val rng = new Random()
    rng.setSeed(seed)

    // Allocate and queue root nodes.
    val topNodes = Array.fill[LearningNode](numTrees)(LearningNode.emptyNode(nodeIndex = 1))
    Range(0, numTrees).foreach(treeIndex => nodeStack.push((treeIndex, topNodes(treeIndex))))

    timer.stop("init")

    while (nodeStack.nonEmpty) {
      // Collect some nodes to split, and choose features for each node (if subsampling).
      // Each group of nodes may come from one or multiple trees, and at multiple levels.
      val (nodesForGroup, treeToNodeToIndexInfo) =
        RandomForest.selectNodesToSplit(nodeStack, maxMemoryUsage, metadata, rng)
      // Sanity check (should never occur):
      assert(nodesForGroup.nonEmpty,
        s"RandomForest selected empty nodesForGroup.  Error for unknown reason.")

      //找到nodes对应的树的根节点,只需要相应的trees
      // Only send trees to worker if they contain nodes being split this iteration.
      val topNodesForGroup: Map[Int, LearningNode] =
        nodesForGroup.keys.map(treeIdx => treeIdx -> topNodes(treeIdx)).toMap

      // Choose node splits, and enqueue new nodes as needed.
      timer.start("findBestSplits")
      RandomForest.findBestSplits(baggedInput, metadata, topNodesForGroup, nodesForGroup,
        treeToNodeToIndexInfo, splits, nodeStack, timer, nodeIdCache)
      timer.stop("findBestSplits")
    }

2、若干Nodes寻找最优

  1. 首先,每个分区内部循环所有数据,计算相应统计量;
  2. 每个分区计算完成后,通过reduceByKey,将同一个node的数据merge到一起;
  3. 最后,对每个node的统计量,寻找最后split;
下面粘贴整体思路的代码(RandomForest.scala的findBestSplits方法中)
// In each partition, iterate all instances and compute aggregate stats for each node,
    // yield a (nodeIndex, nodeAggregateStats) pair for each node.
    // After a `reduceByKey` operation,
    // stats of a node will be shuffled to a particular partition and be combined together,
    // then best splits for nodes are found there.
    // Finally, only best Splits for nodes are collected to driver to construct decision tree.
    val nodeToFeatures = getNodeToFeatures(treeToNodeToIndexInfo)
    val nodeToFeaturesBc = input.sparkContext.broadcast(nodeToFeatures)

    //利用mapPartition计算每个partition的统计量
    val partitionAggregates: RDD[(Int, DTStatsAggregator)] = if (nodeIdCache.nonEmpty) {
      input.zip(nodeIdCache.get.nodeIdsForInstances).mapPartitions { points =>
        // Construct a nodeStatsAggregators array to hold node aggregate stats,
        // each node will have a nodeStatsAggregator
        val nodeStatsAggregators = Array.tabulate(numNodes) { nodeIndex =>
          val featuresForNode = nodeToFeaturesBc.value.map { nodeToFeatures =>
            nodeToFeatures(nodeIndex)
          }
          new DTStatsAggregator(metadata, featuresForNode)
        }

        // iterator all instances in current partition and update aggregate stats
        points.foreach(binSeqOpWithNodeIdCache(nodeStatsAggregators, _))

        // transform nodeStatsAggregators array to (nodeIndex, nodeAggregateStats) pairs,
        // which can be combined with other partition using `reduceByKey`
        nodeStatsAggregators.view.zip
  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值