spark ml 随机森林源码笔记四

继续binsToBestSplit,这个函数是找到一个节点的最优划分

    // Calculate InformationGain and ImpurityStats if current node is top node 如果当前节点是根节点计算信息增益和不纯性统计
    val level = LearningNode.indexToLevel(node.id) //indexToLevel是根据node.id计算level,这个level是从0开始

    var gainAndImpurityStats: ImpurityStats = if (level ==0) {null} else { node.stats}//如果是根节点增益和不纯性统计是null,因为构造跟节点的时候只放了nodeIndex,不是根节点就是node.stats,这说明后面会计算stats

    // For each (feature, split), calculate the gain, and select the best (feature, split).//对每个属性的每个划分,计算增益,并选择最优的属性划分
    val (bestSplit, bestSplitStats) = //返回最优划分,最优划分的统计
      Range(0, binAggregates.metadata.numFeaturesPerNode).map { featureIndexIdx =>//循环节点的每个属性
        val featureIndex = if (featuresForNode.nonEmpty) { //找到每个属性的索引
          featuresForNode.get.apply(featureIndexIdx)
        } else {
          featureIndexIdx
        }
        val numSplits = binAggregates.metadata.numSplits(featureIndex)//求出这个属性有多少划分
        if (binAggregates.metadata.isContinuous(featureIndex)) {//下面分三种情况计算,连续属性,unordered属性,ordered属性
          // Cumulative sum (scanLeft) of bin statistics. //聚合bin统计,这样每个bin的聚合信息就是包括这个bin和之前bin的聚合信息
          // Afterwards, binAggregates for a bin is the sum of aggregates for
          // that bin + all preceding bins.
          val nodeFeatureOffset = binAggregates.getFeatureOffset(featureIndexIdx)//计算属性索引的偏移
          var splitIndex = 0
          while (splitIndex < numSplits) {//遍历每个划分
            binAggregates.mergeForFeature(nodeFeatureOffset, splitIndex + 1, splitIndex)//把前面的bin聚合的后面的bin上,改变的是后面的bin,这样就实现了注释说的功能
            splitIndex += 1
          }
          // Find best split.
          val (bestFeatureSplitIndex, bestFeatureGainStats) =//寻找最优划分,返回最优属性划分索引和最优属性增益统计
            Range(0, numSplits).map { case splitIdx =>//对每个划分
              val leftChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, splitIdx)//获取属性划分的不纯性值,这时统计值已经被更新成累计值了
              val rightChildStats =
                binAggregates.getImpurityCalculator(nodeFeatureOffset, numSplits)//获取属性总的不纯性值

              rightChildStats.subtract(leftChildStats)//不纯性相减,结果更新到rightChildStats

下面我们先看下calculateImpurityStats这个方法,就是计算不纯性的

  private def calculateImpurityStats(
      stats: ImpurityStats,//不纯性统计,如果是根节点是null,这是我们上面提到的
      leftImpurityCalculator: ImpurityCalculator,//左不纯性统计
      rightImpurityCalculator: ImpurityCalculator,//右不纯性统计,根据前面计算结果已经是不纯性的差了
      metadata: DecisionTreeMetadata): ImpurityStats = {//元数据


    val parentImpurityCalculator: ImpurityCalculator = if (stats == null) {//父节点不纯性统计
      leftImpurityCalculator.copy.add(rightImpurityCalculator)//如果是根节点把左右相加,其实还是最初numSplit对应的不纯性统计
    } else {
      stats.impurityCalculator//否则取传过来的不纯性统计
    }


    val impurity: Double = if (stats == null) {
      parentImpurityCalculator.calculate()//根节点计算不纯性值,计算公式就是    val squaredLoss = sumSquares - (sum * sum) / count,返回squaredLoss / count,用的都是统计信息
    } else {
      stats.impurity//否则直接取传过来的不纯性值
    }


    val leftCount = leftImpurityCalculator.count//左不纯性包含的实例数量
    val rightCount = rightImpurityCalculator.count//右不纯性包含实例数量


    val totalCount = leftCount + rightCount//总数


    // If left child or right child doesn't satisfy minimum instances per node,//如果左右子树不满足最小实例条件,划分无效,返回无效信息增益统计
    // then this split is invalid, return invalid information gain stats.
    if ((leftCount < metadata.minInstancesPerNode) ||
      (rightCount < metadata.minInstancesPerNode)) {
      return ImpurityStats.getInvalidImpurityStats(parentImpurityCalculator)//返回一个无效的不纯性统计
    }


    val leftImpurity = leftImpurityCalculator.calculate() // Note: This equals 0 if count = 0//获取左子树不纯性
    val rightImpurity = rightImpurityCalculator.calculate()//获取右子树不纯性


    val leftWeight = leftCount / totalCount.toDouble//左子树权重
    val rightWeight = rightCount / totalCount.toDouble//右子树权重


    val gain = impurity - leftWeight * leftImpurity - rightWeight * rightImpurity//计算信息增益


    // if information gain doesn't satisfy minimum information gain,//如果信息增益不满足最小增益量,返回无效信息增益统计
    // then this split is invalid, return invalid information gain stats.
    if (gain < metadata.minInfoGain) {
     

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值