spark源码分析--GradientBoostedTrees和RandomForest

spark源码分析–GradientBoostedTrees和RandomForest

GradientBoostedTree是spark mllib中的一个树模型,源码见GradientBoostedTrees.scala。该模型仅适用于回归和二分类问题。

训练调用方法

import org.apache.spark.mllib.tree.GradientBoostedTrees

//配置训练策略的参数
val numTrees = 2 //设置树的个数
val boostingStrategy = BoostingStrategy.defaultParams("Classification")//初始化提升策略
boostingStrategy.setNumIterations(numTrees)//为提升策略设置迭代次数,这里其实就是树的个数
val treeStratery = Strategy.defaultStrategy("Classification")//设置树的默认策略为分类问题
treeStratery.setMaxDepth(5)//设置树的最大深度为5
treeStratery.setNumClasses(2)//设置分类类别数2
//    treeStratery.setCategoricalFeaturesInfo(Map[Int, Int]())//可以指定类目特征和连续数值特征
boostingStrategy.setTreeStrategy(treeStratery)//把树策略放置到提升策略中

//输入训练数据train,格式为LabeledPoint,传入提升策略boostingStrategy
val gbdtModel = GradientBoostedTrees.train(train, boostingStrategy)

训练过程分析

训练的时候我们调用的是mllib中树模型的object的方法,首先我们来看下object的调用方法:

@Since("1.2.0")
object GradientBoostedTrees extends Logging {
   

  /**
   * Method to train a gradient boosting model.
   *
   * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
   *              For classification, labels should take values {0, 1, ..., numClasses-1}.
   *              For regression, labels are real numbers.
   * @param boostingStrategy Configuration options for the boosting algorithm.
   * @return GradientBoostedTreesModel that can be used for prediction.
   */
  @Since("1.2.0")
  def train(
      input: RDD[LabeledPoint],
      boostingStrategy: BoostingStrategy): GradientBoostedTreesModel = {
    new GradientBoostedTrees(boostingStrategy, seed = 0).run(input)
  }

  /**
   * Java-friendly API for [[org.apache.spark.mllib.tree.GradientBoostedTrees$#train]]
   */
  @Since("1.2.0")
  def train(
      input: JavaRDD[LabeledPoint],
      boostingStrategy: BoostingStrategy): GradientBoostedTreesModel = {
    train(input.rdd, boostingStrategy)
  }
}

这里第二个train方法提供了一个面向Java API的方法。实质上还是调用第一个train方法,返回一个GradientBoostedTreesModel。这里调用了私有类GradientBoostedTrees实例化了一个对象,然后调用该对象的run方法训练。

//导入ml库中树模型的接口中的GradientBoostedTrees,重命名为NewGBT
import org.apache.spark.ml.tree.impl.{GradientBoostedTrees => NewGBT}
import org.apache.spark.ml.feature.{LabeledPoint => NewLabeledPoint}

@Since("1.2.0")
  def run(input: RDD[LabeledPoint]): GradientBoostedTreesModel = {
    //从传入的提升策略里读取算法,算法只有classification和regression两类
    val algo = boostingStrategy.treeStrategy.algo
    //调用ml树模型中的训练接口,生成学到决策树和对应权重
    val (trees, treeWeights) = NewGBT.run(input.map { point =>
      NewLabeledPoint(point.label, point.features.asML)
    }, boostingStrategy, seed.toLong)
    //返回一个GradientBoostedTreesModel的对象
    new GradientBoostedTreesModel(algo, trees.map(_.toOld), treeWeights)
  }

接下来让我们来看看ml库中的GBDT接口是怎么实现的:

private[spark] object GradientBoostedTrees extends Logging {
   

  /**
   * Method to train a gradient boosting model
   * @param input Training dataset: RDD of `LabeledPoint`.
   * @param seed Random seed.
   * @return tuple of ensemble models and weights:
   *         (array of decision tree models, array of model weights)
   */
  def run(
      input: RDD[LabeledPoint],
      boostingStrategy: OldBoostingStrategy,
      seed: Long): (Array[DecisionTreeRegressionModel], Array[Double]) = {
    val algo = boostingStrategy.treeStrategy.algo
    algo match {
      case OldAlgo.Regression =>
        GradientBoostedTrees.boost(input, input, boostingStrategy, validate = false, seed)
      case OldAlgo.Classification =>
        // Map labels to -1, +1 so binary classification can be treated as regression.
        val remappedInput = input.map(x => new LabeledPoint((x.label * 2) - 1, x.features))
        GradientBoostedTrees.boost(remappedInput, remappedInput, boostingStrategy, validate = false,
          seed)
      case _ =>
        throw new IllegalArgumentException(s"$algo is not supported by gradient boosting.")
    }
  }
  def boost(...)
  ....
  }

这里训练前,先对传入的算法策略内的问题进行判别,如果是回归问题,直接调用boost提升方法开始训练,如果是分类问题,那么需要把类别标签(0或者1)映射成-1和1,这样把分类问题转换成回归问题。最后都是调用了boost方法训练,然后我们看下boost是如何训练的:

def boost(
      input: RDD[LabeledPoint],
      validationInput: RDD[LabeledPoint],
      boostingStrategy: OldBoostingStrategy,
      validate: Boolean,
      seed: Long): (Array[DecisionTreeRegressionModel], Array[Double]) = {
    val timer = new TimeTracker()
    timer.start("total")
    timer.start("init")

    boostingStrategy.assertValid()

    // 首先读取提升策略中的各项参数配置
    val numIterations = boostingStrategy.numIterations
    //设置迭代次数,这里迭代次数就是树的个数,也就是基学习器的个数,每个基学习器都是一个决策树回归模型
    val baseLearners = new Array[DecisionTreeRegressionModel](numIterations)
    //开辟数组存储每次迭代树的权重,最后把每颗迭代树的结果加权作为最后结果
    val baseLearnerWeights = new Array[Double](numIterations)
    //指定训练的损失函数
    val loss = boostingStrategy.loss
    //指定学习率
    val learningRate = boostingStrategy.learningRate
    // Prepare strategy for individual trees, which use regression with variance impurity.
    val treeStrategy = boostingStrategy.treeStrategy.copy
    val validationTol = boostingStrategy.validationTol
    //使用mllib中的回归算法
    treeStrategy.algo = OldAlgo.Regression
    //使用mllib中的不纯度计算方法
    treeStrategy.impurity = OldVariance
    treeStrategy.assertValid()

    // 对输入的训练数据进行缓存,如果没有指定缓存级别,则使用存于内存和硬盘的级别
    val persistedInput = if (input.getStorageLevel == StorageLevel.NONE) {
      input.persist(StorageLevel.MEMORY_AND_DISK)
      true
    } else {
      false
    }

    // 设置周期性的缓存点,存储中间临时结果,防止意外崩掉前功尽弃
    val predErrorCheckpointer = new PeriodicRDDCheckpointer[(Double, Double)](
      treeStrategy.getCheckpointInterval, input.sparkContext)
    val validatePredErrorCheckpointer = new PeriodicRDDCheckpointer[(Double, Double)](
      treeStrategy.getCheckpointInterval, input.sparkContext)

    timer.stop("init")

    //开始建树
    logDebug("##########")
    logDebug("Building tree 0")
    logDebug("##########")

    // 初始化、训练第一棵决策树
    timer.start("building tree 0")
    val firstTree = new DecisionTreeRegressor().setSeed(seed)
    val firstTreeModel = firstTree.train(input, treeStrategy)
    //让第一棵树的权重为1.0
    val firstTreeWeight = 1.0
    baseLearners(0) = firstTreeModel
    baseLearnerWeights(0) = firstTreeWeight

    //计算首颗树的预测误差
    var predError: RDD[(Double, Double)] =
      computeInitialPredictionAndError(input, firstTreeWeight, firstTreeModel, loss)
    predErrorCheckpointer.update(predError)
    logDebug("error of gbt = " + predError.values.mean())

    // Note: A model of type regression is used since we require raw prediction
    timer.stop("building tree 0")

    //计算首颗树验证集误差
    var validatePredError: RDD[(Double, Double)] =
      computeInitialPredictionAndError(validationInput, firstTreeWeight, firstTreeModel, loss)
    if (validate) validatePredErrorCheckpointer.update(validatePredError)
    var bestValidateError = if (validate) validatePredError.values.mean() else 0.0
    var bestM = 1

    //训练剩余的m-1颗提升树
    var m = 1
    var doneLearning = false
    while (m < numIterations && !doneLearning) {
      // 用伪残差作为新的标签,更新训练数据
      val data = predError.zip(input).map { case ((pred, _), point) =>
        LabeledPoint(-loss.gradient(pred, point.label), point.features)
      }

      timer.start(s"building tree $m")
      logDebug("###################################################")
      logDebug("Gradient boosting tree iteration " + m)
      logDebug("###################################################")
      //这里实例化了一个决策回归器
      val dt = new DecisionTreeRegressor().setSeed(seed + m)
      val model = dt.train(data, treeS
  • 1
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值