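Spark Source Code Analysis: GradientBoostedTrees and RandomForest
GradientBoostedTrees is one of the tree models in Spark MLlib; its source is in GradientBoostedTrees.scala. The model supports only regression and binary classification.
Invoking training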
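import org.apache.spark.mllib.tree.GradientBoostedTrees
import org.apache.spark.mllib.tree.configuration.{BoostingStrategy, Strategy}
// Configure the training strategy
val numTrees = 2 // number of trees
val boostingStrategy = BoostingStrategy.defaultParams("Classification") // initialize the boosting strategy
boostingStrategy.setNumIterations(numTrees) // number of boosting iterations, which is the number of trees
val treeStrategy = Strategy.defaultStrategy("Classification") // default tree strategy for a classification problem
treeStrategy.setMaxDepth(5) // maximum tree depth of 5
treeStrategy.setNumClasses(2) // two classes
// treeStrategy.setCategoricalFeaturesInfo(Map[Int, Int]()) // optionally declare which features are categorical; the rest are treated as continuous
boostingStrategy.setTreeStrategy(treeStrategy) // attach the tree strategy to the boosting strategy
// train is the training data as an RDD[LabeledPoint]; pass it in together with boostingStrategy
val gbdtModel = GradientBoostedTrees.train(train, boostingStrategy)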
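Analysis of the training process
Training is started through a method on the mllib tree model's object, so let us first look at the methods that object exposes: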
@Since("1.2.0")
object GradientBoostedTrees extends Logging {
/**
* Method to train a gradient boosting model.
*
* @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
* For classification, labels should take values {0, 1, ..., numClasses-1}.
* For regression, labels are real numbers.
* @param boostingStrategy Configuration options for the boosting algorithm.
* @return GradientBoostedTreesModel that can be used for prediction.
*/
@Since("1.2.0")
def train(
input: RDD[LabeledPoint],
boostingStrategy: BoostingStrategy): GradientBoostedTreesModel = {
new GradientBoostedTrees(boostingStrategy, seed = 0).run(input)
}
/**
* Java-friendly API for [[org.apache.spark.mllib.tree.GradientBoostedTrees$#train]]
*/
@Since("1.2.0")
def train(
input: JavaRDD[LabeledPoint],
boostingStrategy: BoostingStrategy): GradientBoostedTreesModel = {
train(input.rdd, boostingStrategy)
}
}
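The second train method here is a Java-friendly overload. It simply delegates to the first train method, which returns a GradientBoostedTreesModel. The first method instantiates the companion class GradientBoostedTrees and then calls that object's run method to do the training.
// Import GradientBoostedTrees from the ml library's tree implementation, renamed to NewGBT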
import org.apache.spark.ml.tree.impl.{GradientBoostedTrees => NewGBT}
import org.apache.spark.ml.feature.{LabeledPoint => NewLabeledPoint}
@Since("1.2.0")
def run(input: RDD[LabeledPoint]): GradientBoostedTreesModel = {
// Read the algorithm from the boosting strategy; it is either classification or regression
val algo = boostingStrategy.treeStrategy.algo
// Call the training entry point of the ml tree implementation, which returns the learned decision trees and their weights
val (trees, treeWeights) = NewGBT.run(input.map { point =>
NewLabeledPoint(point.label, point.features.asML)
}, boostingStrategy, seed.toLong)
// Return a GradientBoostedTreesModel
new GradientBoostedTreesModel(algo, trees.map(_.toOld), treeWeights)
}
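The model above is built from an array of trees and a matching array of weights. As a rough illustration of how such a weighted ensemble could turn per-tree outputs into one prediction, here is a minimal plain-Scala sketch. It does not use Spark; the `Tree` type alias, the object name, and both function names are hypothetical stand-ins, and thresholding the summed score at 0 for classification is an assumption based on the -1/+1 label encoding used during training.

```scala
object WeightedEnsembleSketch {
  // Hypothetical stand-in for a trained tree: a function from features to a raw score.
  type Tree = Array[Double] => Double

  // Raw prediction: each tree's output scaled by its weight, then summed.
  def rawPredict(trees: Seq[Tree], weights: Seq[Double], features: Array[Double]): Double = {
    require(trees.length == weights.length, "one weight per tree is required")
    trees.zip(weights).map { case (tree, w) => w * tree(features) }.sum
  }

  // Assumed decision rule for binary classification: positive score means class 1.0, otherwise 0.0.
  def classify(trees: Seq[Tree], weights: Seq[Double], features: Array[Double]): Double =
    if (rawPredict(trees, weights, features) > 0.0) 1.0 else 0.0
}
```

For regression the raw weighted sum would be used directly; only the classification case needs the final thresholding step.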
Next, let us see how the GBDT entry point in the ml library is implemented:
private[spark] object GradientBoostedTrees extends Logging {
/**
* Method to train a gradient boosting model
* @param input Training dataset: RDD of `LabeledPoint`.
* @param seed Random seed.
* @return tuple of ensemble models and weights:
* (array of decision tree models, array of model weights)
*/
def run(
input: RDD[LabeledPoint],
boostingStrategy: OldBoostingStrategy,
seed: Long): (Array[DecisionTreeRegressionModel], Array[Double]) = {
val algo = boostingStrategy.treeStrategy.algo
algo match {
case OldAlgo.Regression =>
GradientBoostedTrees.boost(input, input, boostingStrategy, validate = false, seed)
case OldAlgo.Classification =>
// Map labels to -1, +1 so binary classification can be treated as regression.
val remappedInput = input.map(x => new LabeledPoint((x.label * 2) - 1, x.features))
GradientBoostedTrees.boost(remappedInput, remappedInput, boostingStrategy, validate = false,
seed)
case _ =>
throw new IllegalArgumentException(s"$algo is not supported by gradient boosting.")
}
}
def boost(...)
....
}
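Before training, the code inspects the problem type stored in the strategy. For regression it calls the boost method directly. For classification it first maps the class labels (0 or 1) to -1 and +1 via label * 2 - 1, which turns the binary classification problem into a regression problem. Both branches end up calling boost, so let us look at how boost trains: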
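The label remapping is small enough to check by hand. The following plain-Scala sketch mirrors the expression `(x.label * 2) - 1` from the source; the object name and the inverse helper `toClass` are hypothetical and are included only to show that the mapping is reversible.

```scala
object LabelRemapSketch {
  // Map a class label in {0, 1} to {-1, +1}, mirroring (x.label * 2) - 1 in the source.
  def toSigned(label: Double): Double = (label * 2) - 1

  // Hypothetical inverse: map a signed label in {-1, +1} back to {0, 1}.
  def toClass(signed: Double): Double = (signed + 1) / 2
}
```

With this encoding the two classes sit symmetrically around zero, which is why a regression tree's raw output can be interpreted by its sign.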
def boost(
input: RDD[LabeledPoint],
validationInput: RDD[LabeledPoint],
boostingStrategy: OldBoostingStrategy,
validate: Boolean,
seed: Long): (Array[DecisionTreeRegressionModel], Array[Double]) = {
val timer = new TimeTracker()
timer.start("total")
timer.start("init")
boostingStrategy.assertValid()
// First read the parameters from the boosting strategy
val numIterations = boostingStrategy.numIterations
// The number of iterations equals the number of trees, i.e. base learners; each base learner is a decision tree regression model
val baseLearners = new Array[DecisionTreeRegressionModel](numIterations)
// Array holding the weight of each tree; the final prediction is the weighted combination of all trees' outputs
val baseLearnerWeights = new Array[Double](numIterations)
// Loss function used for training
val loss = boostingStrategy.loss
// Learning rate
val learningRate = boostingStrategy.learningRate
// Prepare strategy for individual trees, which use regression with variance impurity.
val treeStrategy = boostingStrategy.treeStrategy.copy
val validationTol = boostingStrategy.validationTol
// Every individual tree is trained as a regression tree
treeStrategy.algo = OldAlgo.Regression
// Use variance as the impurity measure
treeStrategy.impurity = OldVariance
treeStrategy.assertValid()
// Cache the input training data; if no storage level has been set, persist it to memory and disk
val persistedInput = if (input.getStorageLevel == StorageLevel.NONE) {
input.persist(StorageLevel.MEMORY_AND_DISK)
true
} else {
false
}
// Set up periodic checkpointing of intermediate results so that an unexpected failure does not waste the work done so far
val predErrorCheckpointer = new PeriodicRDDCheckpointer[(Double, Double)](
treeStrategy.getCheckpointInterval, input.sparkContext)
val validatePredErrorCheckpointer = new PeriodicRDDCheckpointer[(Double, Double)](
treeStrategy.getCheckpointInterval, input.sparkContext)
timer.stop("init")
// Start building trees
logDebug("##########")
logDebug("Building tree 0")
logDebug("##########")
// Initialize and train the first decision tree
timer.start("building tree 0")
val firstTree = new DecisionTreeRegressor().setSeed(seed)
val firstTreeModel = firstTree.train(input, treeStrategy)
// The first tree gets weight 1.0
val firstTreeWeight = 1.0
baseLearners(0) = firstTreeModel
baseLearnerWeights(0) = firstTreeWeight
// Compute the prediction and error of the first tree on the training data
var predError: RDD[(Double, Double)] =
computeInitialPredictionAndError(input, firstTreeWeight, firstTreeModel, loss)
predErrorCheckpointer.update(predError)
logDebug("error of gbt = " + predError.values.mean())
// Note: A model of type regression is used since we require raw prediction
timer.stop("building tree 0")
// Compute the prediction and error of the first tree on the validation data
var validatePredError: RDD[(Double, Double)] =
computeInitialPredictionAndError(validationInput, firstTreeWeight, firstTreeModel, loss)
if (validate) validatePredErrorCheckpointer.update(validatePredError)
var bestValidateError = if (validate) validatePredError.values.mean() else 0.0
var bestM = 1
// Train the remaining numIterations - 1 boosted trees
var m = 1
var doneLearning = false
while (m < numIterations && !doneLearning) {
// Use the pseudo-residuals (the negative gradient of the loss) as the new labels for the training data
val data = predError.zip(input).map { case ((pred, _), point) =>
LabeledPoint(-loss.gradient(pred, point.label), point.features)
}
timer.start(s"building tree $m")
logDebug("###################################################")
logDebug("Gradient boosting tree iteration " + m)
logDebug("###################################################")
// Instantiate a decision tree regressor for this iteration
val dt = new DecisionTreeRegressor().setSeed(seed + m)
val model = dt.train(data, treeS