Spark 1.2.0 MLlib Source Code --- Decision Tree - 02

This part focuses on how the candidate split points (and their bins) are determined for each feature.


The work is done in DecisionTree.findSplitsBins():

  protected[tree] def findSplitsBins(
      input: RDD[LabeledPoint],
      metadata: DecisionTreeMetadata): (Array[Array[Split]], Array[Array[Bin]]) = {

    logDebug("isMulticlass = " + metadata.isMulticlass)

    val numFeatures = metadata.numFeatures  // number of features

    // Sample the input only if there are continuous features.
    val hasContinuousFeatures = Range(0, numFeatures).exists(metadata.isContinuous)  // is any feature continuous?
    val sampledInput = if (hasContinuousFeatures) {  // at least one continuous feature
      // Calculate the number of samples for approximate quantile calculation.
      val requiredSamples = math.max(metadata.maxBins * metadata.maxBins, 10000)  // a subsample is used to locate the splits; at least 10000 samples
      val fraction = if (requiredSamples < metadata.numExamples) {
        requiredSamples.toDouble / metadata.numExamples  // sampling fraction
      } else {
        1.0
      }
      logDebug("fraction of data used for calculating quantiles = " + fraction)
      input.sample(withReplacement = false, fraction, new XORShiftRandom().nextInt()).collect()
    } else {  // only categorical features
      new Array[LabeledPoint](0)
    }
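    // Illustrative arithmetic (not part of the Spark source): with maxBins = 32,
    // requiredSamples = max(32 * 32, 10000) = 10000; for a data set of 1,000,000 examples
    // the sampling fraction is 10000 / 1000000 = 0.01, so roughly 1% of the rows are
    // collected to the driver for the approximate quantile computation.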

    metadata.quantileStrategy match {  // split-point strategy; currently only Sort is supported
      case Sort =>
        val splits = new Array[Array[Split]](numFeatures) // one array of split points per feature
        val bins = new Array[Array[Bin]](numFeatures) // bins that pair up with the split points; each bin stores its boundary split information

        // Find all splits.
        // Iterate over all features.
        var featureIndex = 0
        while (featureIndex < numFeatures) {
          if (metadata.isContinuous(featureIndex)) {  // continuous feature
            val featureSamples = sampledInput.map(lp => lp.features(featureIndex)) // sampled values of feature featureIndex
            val featureSplits = findSplitsForContinuousFeature(featureSamples,  // split thresholds found for this feature (typically several)
              metadata, featureIndex)

            val numSplits = featureSplits.length  // number of split points for this feature
            val numBins = numSplits + 1  // number of bins (cutting a melon twice gives three pieces)
            logDebug(s"featureIndex = $featureIndex, numSplits = $numSplits")
            splits(featureIndex) = new Array[Split](numSplits)  // split-point array
            bins(featureIndex) = new Array[Bin](numBins)   // corresponding bin array

            var splitIndex = 0
            while (splitIndex < numSplits) {
              val threshold = featureSplits(splitIndex)  // value of this split point; since the values are sorted it acts as a threshold
              splits(featureIndex)(splitIndex) =
                new Split(featureIndex, threshold, Continuous, List()) // one Split per split position of this feature
              splitIndex += 1
            }
            bins(featureIndex)(0) = new Bin(new DummyLowSplit(featureIndex, Continuous),
              splits(featureIndex)(0), Continuous, Double.MinValue)  // leftmost bin, bounded below by a dummy low split (threshold Double.MinValue)

            splitIndex = 1
            while (splitIndex < numSplits) {
              bins(featureIndex)(splitIndex) =
                new Bin(splits(featureIndex)(splitIndex - 1), splits(featureIndex)(splitIndex),  // each interior bin covers the values between two consecutive split thresholds
                  Continuous, Double.MinValue)
              splitIndex += 1
            }
            bins(featureIndex)(numSplits) = new Bin(splits(featureIndex)(numSplits - 1),
              new DummyHighSplit(featureIndex, Continuous), Continuous, Double.MinValue) // rightmost bin, bounded above by a dummy high split
          } else {   // categorical feature
            val numSplits = metadata.numSplits(featureIndex)  // number of split points
            val numBins = metadata.numBins(featureIndex)   // number of bins
            // Categorical feature
            val featureArity = metadata.featureArity(featureIndex)  // number of distinct categories of this feature
            if (metadata.isUnordered(featureIndex)) {  // categorical and unordered
              // TODO: The second half of the bins are unused.  Actually, we could just use
              //       splits and not build bins for unordered features.  That should be part of
              //       a later PR since it will require changing other code (using splits instead
              //       of bins in a few places).
              // Unordered features
              //   2^(maxFeatureValue - 1) - 1 combinations
              splits(featureIndex) = new Array[Split](numSplits)
              bins(featureIndex) = new Array[Bin](numBins)
              var splitIndex = 0
              while (splitIndex < numSplits) {
                val categories: List[Double] =
                  extractMultiClassCategories(splitIndex + 1, featureArity) // decode the subset of categories (one or more) assigned to this split
                splits(featureIndex)(splitIndex) =
                  new Split(featureIndex, Double.MinValue, Categorical, categories)
                bins(featureIndex)(splitIndex) = {
                  if (splitIndex == 0) {
                    new Bin(
                      new DummyCategoricalSplit(featureIndex, Categorical),
                      splits(featureIndex)(0),
                      Categorical,
                      Double.MinValue)
                  } else {
                    new Bin(
                      splits(featureIndex)(splitIndex - 1),
                      splits(featureIndex)(splitIndex),
                      Categorical,
                      Double.MinValue)
                  }
                }
                splitIndex += 1
              }
            } else {  // categorical and ordered: nothing to precompute, splits are constructed as needed during training
              // Ordered features
              //   Bins correspond to feature values, so we do not need to compute splits or bins
              //   beforehand.  Splits are constructed as needed during training.
              splits(featureIndex) = new Array[Split](0)
              bins(featureIndex) = new Array[Bin](0)
            }
          }
          featureIndex += 1
        }
        (splits, bins)
      case MinMax =>
        throw new UnsupportedOperationException("minmax not supported yet.")
      case ApproxHist =>
        throw new UnsupportedOperationException("approximate histogram not supported yet.")
    }
  }
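
A note on the unordered-categorical branch: each splitIndex from 0 to numSplits - 1 encodes one subset of categories through the binary representation of splitIndex + 1, which is where the 2^(featureArity - 1) - 1 combinations mentioned in the comment come from. The sketch below is a simplified, self-contained re-implementation of that bit decoding for illustration only; decodeCategories is a made-up name for this post and is not the actual Spark helper (extractMultiClassCategories):

  object UnorderedSplitsDemo {
    // Decode the categories whose bits are set in `input`; this mirrors the idea behind
    // extractMultiClassCategories but is only an illustrative stand-in.
    def decodeCategories(input: Int, featureArity: Int): List[Double] =
      (0 until featureArity).collect {
        case j if ((input >> j) & 1) == 1 => j.toDouble
      }.toList

    def main(args: Array[String]): Unit = {
      val featureArity = 3                           // categories 0.0, 1.0, 2.0
      val numSplits = (1 << (featureArity - 1)) - 1  // 2^(3-1) - 1 = 3 candidate subsets
      (0 until numSplits).foreach { splitIndex =>
        println(s"split $splitIndex -> categories ${decodeCategories(splitIndex + 1, featureArity)}")
      }
      // split 0 -> categories List(0.0)
      // split 1 -> categories List(1.0)
      // split 2 -> categories List(0.0, 1.0)
    }
  }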

Next, let's look at how the split points are located when the feature is continuous:

  private[tree] def findSplitsForContinuousFeature(  // compute the split thresholds of one feature
      featureSamples: Array[Double],
      metadata: DecisionTreeMetadata,
      featureIndex: Int): Array[Double] = {
    require(metadata.isContinuous(featureIndex),
      "findSplitsForContinuousFeature can only be used to find splits for a continuous feature.")

    val splits = {
      val numSplits = metadata.numSplits(featureIndex)   // number of splits

      // get count for each distinct value
      val valueCountMap = featureSamples.foldLeft(Map.empty[Double, Int]) { (m, x) =>
        m + ((x, m.getOrElse(x, 0) + 1))   // count the occurrences of each distinct value: map(value -> count)
      }
      // sort distinct values
      val valueCounts = valueCountMap.toSeq.sortBy(_._1).toArray  // sort the distinct values of this feature in ascending order

      // if possible splits is not enough or just enough, just return all possible splits
      val possibleSplits = valueCounts.length
      if (possibleSplits <= numSplits) {
        valueCounts.map(_._1)  // each distinct value becomes a split point
      } else {  // otherwise several distinct values have to share one split point
        // stride between splits
        val stride: Double = featureSamples.length.toDouble / (numSplits + 1)  // target spacing, in sample count, between consecutive split points
        logDebug("stride = " + stride)

        // iterate `valueCount` to find splits
        val splits = new ArrayBuffer[Double]
        var index = 1
        // currentCount: sum of counts of values that have been visited
        var currentCount = valueCounts(0)._2  // start scanning at the first entry; its value is the occurrence count of that distinct value
        // targetCount: target value for `currentCount`.
        // If `currentCount` is closest value to `targetCount`,
        // then current value is a split threshold.
        // After finding a split threshold, `targetCount` is added by stride.
        var targetCount = stride
        while (index < valueCounts.length) {
          val previousCount = currentCount
          currentCount += valueCounts(index)._2
          val previousGap = math.abs(previousCount - targetCount)
          val currentGap = math.abs(currentCount - targetCount) // once the running count moves past targetCount, the previous value is taken as a split point
          // If adding count of current value to currentCount
          // makes the gap between currentCount and targetCount smaller,
          // previous value is a split threshold.
          if (previousGap < currentGap) {
            splits.append(valueCounts(index - 1)._1)  // record the qualifying feature value as a split threshold
            targetCount += stride
          }
          index += 1
        }

        splits.toArray
      }
    }

    assert(splits.length > 0)
    // set number of splits accordingly
    metadata.setNumSplits(featureIndex, splits.length)  // update the metadata with the actual number of splits found

    splits
  }
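
To see the stride logic in action, here is a self-contained sketch that replays the same scan on a small in-memory sample. pickSplits is a hypothetical helper written for this post; it mirrors the loop above but is not the Spark method itself:

  import scala.collection.mutable.ArrayBuffer

  object StrideSplitDemo {
    // Walk the sorted (value, count) pairs and emit a threshold whenever the running
    // count passes the next stride target, just like findSplitsForContinuousFeature.
    def pickSplits(featureSamples: Array[Double], numSplits: Int): Array[Double] = {
      val valueCounts = featureSamples.groupBy(identity).toSeq
        .map { case (v, occurrences) => (v, occurrences.length) }
        .sortBy(_._1).toArray
      if (valueCounts.length <= numSplits) return valueCounts.map(_._1)

      val stride = featureSamples.length.toDouble / (numSplits + 1)
      val splits = new ArrayBuffer[Double]
      var currentCount = valueCounts(0)._2
      var targetCount = stride
      var index = 1
      while (index < valueCounts.length) {
        val previousCount = currentCount
        currentCount += valueCounts(index)._2
        // If adding the current value widens the gap to targetCount, the previous value
        // is close enough to the target and becomes a split threshold.
        if (math.abs(previousCount - targetCount) < math.abs(currentCount - targetCount)) {
          splits.append(valueCounts(index - 1)._1)
          targetCount += stride
        }
        index += 1
      }
      splits.toArray
    }

    def main(args: Array[String]): Unit = {
      // 12 samples, 4 distinct values, 2 requested splits -> stride = 12 / 3 = 4.0
      val samples = Array(1.0, 1.0, 1.0, 2.0, 2.0, 2.0, 2.0, 3.0, 3.0, 3.0, 4.0, 4.0)
      println(pickSplits(samples, numSplits = 2).mkString(", "))  // prints "1.0, 2.0"
    }
  }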

 *************  The End  ************* 
