Learning "Applied Predictive Modeling" with PySpark (9): A First Look at the Gradient Boosting Source Code

The entry point for training a GBTRegressor is its train method, which calls either GradientBoostedTrees.runWithValidation or GradientBoostedTrees.run, depending on whether a validation set is provided:

    val (baseLearners, learnerWeights) = if (withValidation) {
      GradientBoostedTrees.runWithValidation(trainDataset, validationDataset, boostingStrategy,
        $(seed), $(featureSubsetStrategy), Some(instr))
    } else {
      GradientBoostedTrees.run(trainDataset, boostingStrategy,
        $(seed), $(featureSubsetStrategy), Some(instr))
    }

Both runWithValidation and run then delegate to the boost method, which performs the actual training.
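
From the DataFrame API side, the withValidation branch corresponds to whether validationIndicatorCol has been set on the estimator. Below is a minimal sketch of driving this entry point; the toy data, the column name isVal and the parameter values are purely illustrative:

    import org.apache.spark.ml.linalg.Vectors
    import org.apache.spark.ml.regression.GBTRegressor
    import org.apache.spark.sql.SparkSession

    object GBTEntryDemo {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder().master("local[*]").appName("gbt-entry").getOrCreate()
        import spark.implicits._

        // Toy data; the "isVal" column marks rows reserved for validation.
        val df = Seq(
          (1.0, Vectors.dense(0.0, 1.0), false),
          (1.5, Vectors.dense(0.5, 1.0), false),
          (2.0, Vectors.dense(1.0, 0.0), false),
          (2.5, Vectors.dense(1.0, 0.5), false),
          (3.0, Vectors.dense(1.0, 1.0), true),
          (3.5, Vectors.dense(1.5, 1.0), true)
        ).toDF("label", "features", "isVal")

        // With validationIndicatorCol set, train() takes the runWithValidation branch;
        // leaving it unset takes the plain run() branch instead.
        val gbt = new GBTRegressor()
          .setMaxIter(20)
          .setStepSize(0.1)
          .setValidationIndicatorCol("isVal")
          .setValidationTol(0.01)

        val model = gbt.fit(df)
        println(s"trees trained = ${model.trees.length}")
        spark.stop()
      }
    }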

boost first calls RandomForest.findSplits to prepare the candidate splits (findSplits in turn calls findSplitsBySorting). The result, splits, has type Array[Array[Split]]: one array of candidate splits per feature.
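
Conceptually, splits(f) holds the candidate split thresholds for feature f, and the TreePoint representation used below replaces each raw feature value with the index of the bin those thresholds define. A simplified, self-contained sketch of that idea (plain Double thresholds stand in for Spark's internal Split objects, and the exact boundary convention here is only approximate):

    object BinningSketch {
      // Map a raw feature value to a bin index: bin i covers values up to and
      // including thresholds(i); values above the last threshold get the last bin.
      def binIndex(value: Double, thresholds: Array[Double]): Int = {
        val idx = thresholds.indexWhere(value <= _)
        if (idx >= 0) idx else thresholds.length
      }

      def main(args: Array[String]): Unit = {
        val thresholds = Array(0.5, 1.5, 2.5)   // candidate splits for one continuous feature
        val raw = Array(0.2, 0.7, 3.0)
        println(raw.map(binIndex(_, thresholds)).mkString(","))  // prints 0,1,3
      }
    }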

Training of the first tree then begins:

    val bcSplits = sc.broadcast(splits)

    // Bin feature values (TreePoint representation).
    // Cache input RDD for speedup during multiple passes.
    val treePoints = TreePoint.convertToTreeRDD(
      retaggedInput, splits, metadata)
      .persist(StorageLevel.MEMORY_AND_DISK)
      .setName("binned tree points")

    val firstCounts = BaggedPoint
      .convertToBaggedRDD(treePoints, treeStrategy.subsamplingRate, numSubsamples = 1,
        treeStrategy.bootstrap, (tp: TreePoint) => tp.weight, seed = seed)
      .map { bagged =>
        require(bagged.subsampleCounts.length == 1)
        require(bagged.sampleWeight == bagged.datum.weight)
        bagged.subsampleCounts.head
      }.persist(StorageLevel.MEMORY_AND_DISK)
      .setName("firstCounts at iter=0")

    val firstBagged = treePoints.zip(firstCounts)
      .map { case (treePoint, count) =>
        // according to current design, treePoint.weight == baggedPoint.sampleWeight
        new BaggedPoint[TreePoint](treePoint, Array(count), treePoint.weight)
    }

    val firstTreeModel = RandomForest.runBagged(baggedInput = firstBagged,
      metadata = metadata, bcSplits = bcSplits, strategy = treeStrategy, numTrees = 1,
      featureSubsetStrategy = featureSubsetStrategy, seed = seed, instr = instr,
      parentUID = None)
      .head.asInstanceOf[DecisionTreeRegressionModel]
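
The counts produced by BaggedPoint.convertToBaggedRDD say how many times each point takes part in the subsample used for this tree (0 means the point is skipped). Roughly, and only as a sketch rather than Spark's exact code: with bootstrapping the count follows a Poisson distribution with mean subsamplingRate, and without it each point is kept once with probability subsamplingRate:

    import scala.util.Random

    object SubsampleCountSketch {
      // Poisson sampling via Knuth's method; fine for small means such as a
      // subsampling rate (requires mean > 0).
      def poisson(mean: Double, rng: Random): Int = {
        val l = math.exp(-mean)
        var k = 0
        var p = rng.nextDouble()
        while (p > l) {
          k += 1
          p *= rng.nextDouble()
        }
        k
      }

      // Per-point subsample count for one tree, under the assumptions stated above.
      def subsampleCount(subsamplingRate: Double, bootstrap: Boolean, rng: Random): Int =
        if (bootstrap) poisson(subsamplingRate, rng)
        else if (rng.nextDouble() < subsamplingRate) 1 else 0

      def main(args: Array[String]): Unit = {
        val rng = new Random(42)
        val counts = Array.fill(10)(subsampleCount(subsamplingRate = 0.8, bootstrap = true, rng = rng))
        println(counts.mkString(","))
      }
    }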

Following the GBT procedure, the first tree and its weight are recorded and the initial per-sample predictions and errors are computed; the pseudo-residuals will be derived from these in later iterations:

    val firstTreeWeight = 1.0
    baseLearners(0) = firstTreeModel
    baseLearnerWeights(0) = firstTreeWeight

    var predError = computeInitialPredictionAndError(
      treePoints, firstTreeWeight, firstTreeModel, loss, bcSplits)
    predErrorCheckpointer.update(predError)
    logDebug(s"error of gbt = ${computeWeightedError(treePoints, predError)}")
  /**
   * Compute the initial predictions and errors for a dataset for the first
   * iteration of gradient boosting.
   * @param data: training data.
   * @param initTreeWeight: learning rate assigned to the first tree.
   * @param initTree: first DecisionTreeModel.
   * @param loss: evaluation metric.
   * @return an RDD with each element being a zip of the prediction and error
   *         corresponding to every sample.
   */
  def computeInitialPredictionAndError(
      data: RDD[TreePoint],
      initTreeWeight: Double,
      initTree: DecisionTreeRegressionModel,
      loss: OldLoss,
      bcSplits: Broadcast[Array[Array[Split]]]): RDD[(Double, Double)] = {
    data.map { treePoint =>
      val pred = updatePrediction(treePoint, 0.0, initTree, initTreeWeight, bcSplits.value)
      val error = loss.computeError(pred, treePoint.label)
      (pred, error)
    }
  }
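
Note that the first tree enters the ensemble with weight 1.0 rather than the learning rate, so after this step the model prediction and the recorded per-sample error are simply

    F_1(x_i) = 1.0 \cdot h_1(x_i), \qquad \mathrm{err}_i = L\bigl(F_1(x_i),\, y_i\bigr)

Every subsequent tree is added with weight equal to the learning rate, as the loop below shows with baseLearnerWeights(m) = learningRate.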

More trees are then trained inside a while loop:

    var bestM = 1

    var m = 1
    var doneLearning = false
    while (m < numIterations && !doneLearning) {
      timer.start(s"building tree $m")
      logDebug("###################################################")
      logDebug("Gradient boosting tree iteration " + m)
      logDebug("###################################################")

      // (label: Double, count: Int)
      val labelWithCounts = BaggedPoint
        .convertToBaggedRDD(treePoints, treeStrategy.subsamplingRate, numSubsamples = 1,
          treeStrategy.bootstrap, (tp: TreePoint) => tp.weight, seed = seed + m)
        .zip(predError)
        .map { case (bagged, (pred, _)) =>
          require(bagged.subsampleCounts.length == 1)
          require(bagged.sampleWeight == bagged.datum.weight)
          // Update labels with pseudo-residuals
          val newLabel = -loss.gradient(pred, bagged.datum.label)
          (newLabel, bagged.subsampleCounts.head)
        }.persist(StorageLevel.MEMORY_AND_DISK)
        .setName(s"labelWithCounts at iter=$m")

      val bagged = treePoints.zip(labelWithCounts)
        .map { case (treePoint, (newLabel, count)) =>
          val newTreePoint = new TreePoint(newLabel, treePoint.binnedFeatures, treePoint.weight)
          // according to current design, treePoint.weight == baggedPoint.sampleWeight
          new BaggedPoint[TreePoint](newTreePoint, Array(count), treePoint.weight)
        }

      val model = RandomForest.runBagged(baggedInput = bagged,
        metadata = metadata, bcSplits = bcSplits, strategy = treeStrategy,
        numTrees = 1, featureSubsetStrategy = featureSubsetStrategy,
        seed = seed + m, instr = None, parentUID = None)
        .head.asInstanceOf[DecisionTreeRegressionModel]

      labelWithCounts.unpersist()

      timer.stop(s"building tree $m")
      // Update partial model
      baseLearners(m) = model
      // Note: The setting of baseLearnerWeights is incorrect for losses other than SquaredError.
      //       Technically, the weight should be optimized for the particular loss.
      //       However, the behavior should be reasonable, though not optimal.
      baseLearnerWeights(m) = learningRate

      predError = updatePredictionError(
        treePoints, predError, baseLearnerWeights(m),
        baseLearners(m), loss, bcSplits)
      predErrorCheckpointer.update(predError)
      logDebug(s"error of gbt = ${computeWeightedError(treePoints, predError)}")

      if (validate) {
        // Stop training early if
        // 1. Reduction in error is less than the validationTol or
        // 2. If the error increases, that is if the model is overfit.
        // We want the model returned corresponding to the best validation error.

        validatePredError = updatePredictionError(
          validationTreePoints, validatePredError, baseLearnerWeights(m),
          baseLearners(m), loss, bcSplits)
        validatePredErrorCheckpointer.update(validatePredError)
        val currentValidateError = computeWeightedError(validationTreePoints, validatePredError)
        if (bestValidateError - currentValidateError < validationTol * Math.max(
          currentValidateError, 0.01)) {
          doneLearning = true
        } else if (currentValidateError < bestValidateError) {
          bestValidateError = currentValidateError
          bestM = m + 1
        }
      }
      m += 1
    }
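
updatePredictionError itself is not shown in this post, but it follows the same pattern as computeInitialPredictionAndError above: for every sample it adds treeWeight times the new tree's prediction to the running prediction and recomputes the loss. A simplified, RDD-free sketch of that accumulation under a squared-error loss (the helper names and toy data below are mine, not Spark's):

    object BoostUpdateSketch {
      // One boosting step on plain arrays: prediction += treeWeight * treePrediction,
      // then the per-sample squared error is recomputed against the true label.
      def updatePredictionError(
          labels: Array[Double],
          predError: Array[(Double, Double)],
          treePredictions: Array[Double],
          treeWeight: Double): Array[(Double, Double)] = {
        labels.indices.map { i =>
          val newPred = predError(i)._1 + treeWeight * treePredictions(i)
          val err = labels(i) - newPred
          (newPred, err * err)
        }.toArray
      }

      def main(args: Array[String]): Unit = {
        val labels = Array(1.0, 2.0, 3.0)
        var predError = labels.map(_ => (0.0, 0.0))
        for (_ <- 0 until 3) {
          // Pretend each new "tree" predicts the current residual exactly.
          val treePred = labels.zip(predError).map { case (y, (p, _)) => y - p }
          predError = updatePredictionError(labels, predError, treePred, treeWeight = 0.1)
          println(f"total squared error = ${predError.map(_._2).sum}%.4f")  // shrinks every round
        }
      }
    }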

As in J.H. Friedman's paper, each new tree is actually fit to the negative gradient of the loss (the pseudo-residuals) rather than to the raw residuals:

    val newLabel = -loss.gradient(pred, bagged.datum.label)
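
Written out, the pseudo-residual for sample i at iteration m is the negative gradient of the loss evaluated at the current model; for the squared error L(y, F) = (y - F)^2 it is simply twice the ordinary residual:

    \tilde{y}_{im} = -\left[ \frac{\partial L(y_i, F(x_i))}{\partial F(x_i)} \right]_{F = F_{m-1}},
    \qquad
    L(y, F) = (y - F)^2 \;\Rightarrow\; \tilde{y}_{im} = 2\,\bigl(y_i - F_{m-1}(x_i)\bigr)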

This completes a rough outline of GBT training in Spark. Plenty of details are still not fully worked out; I'll come back to them if the need ever arises O(∩_∩)O
