spark1.2.0源码MLlib --- 决策树-03

本文深入探讨Spark 1.2.0版本中MLlib模块的决策树实现,重点在于理解在节点分裂过程中,如何对数据进行汇总以计算不纯度和信息增益。通过分析`DecisionTree.findBestSplits()`方法,特别是`points.foreach(binSeqOp(nodeStatsAggregators, _))`这一步,展示了按分区聚合数据的细节,以及EntropyAggregator的`update()`方法在计算熵中的作用。" 126181758,1944903,Transformer模型解析:PyTorch实现与应用,"['深度学习', '自然语言处理', 'PyTorch', '模型结构']
摘要由CSDN通过智能技术生成

本章重点关注树中各节点分裂过程中,如何将相应的数据进行汇总,以便之后计算节点不纯度及信息增益,最终确定分裂的顺序。


首先,从 DecisionTree.findBestSplits() 开始,这个方法代码很长,按照执行顺序来看,代码如下:

    val partitionAggregates : RDD[(Int, DTStatsAggregator)] = if (nodeIdCache.nonEmpty) {  //节点缓存的情况
      input.zip(nodeIdCache.get.nodeIdsForInstances).mapPartitions { points =>
        // Construct a nodeStatsAggregators array to hold node aggregate stats,
        // each node will have a nodeStatsAggregator
        val nodeStatsAggregators = Array.tabulate(numNodes) { nodeIndex =>
          val featuresForNode = nodeToFeaturesBc.value.flatMap { nodeToFeatures =>
            Some(nodeToFeatures(nodeIndex))
          }
          new DTStatsAggregator(metadata, featuresForNode)
        }

        // iterator all instances in current partition and update aggregate stats
        points.foreach(binSeqOpWithNodeIdCache(nodeStatsAggregators, _))

        // transform nodeStatsAggregators array to (nodeIndex, nodeAggregateStats) pairs,
        // which can be combined with other partition using `reduceByKey`
        nodeStatsAggregators.view.zipWithIndex.map(_.swap).iterator
      }
    } else {  //节点不缓存的情况
      in
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值