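The entry point for training a GBTRegressor is its train method. Depending on whether a validation set is used, train calls either GradientBoostedTrees.runWithValidation or GradientBoostedTrees.run: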
```scala
val (baseLearners, learnerWeights) = if (withValidation) {
  GradientBoostedTrees.runWithValidation(trainDataset, validationDataset, boostingStrategy,
    $(seed), $(featureSubsetStrategy), Some(instr))
} else {
  GradientBoostedTrees.run(trainDataset, boostingStrategy,
    $(seed), $(featureSubsetStrategy), Some(instr))
}
```
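Both runWithValidation and run then call boost, which carries out the actual training.
boost first calls RandomForest.findSplits to compute the candidate splits, and findSplits in turn delegates to findSplitsBySorting. The result, splits, has type Array[Array[Split]]: one array of candidate splits per feature.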
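To illustrate what the splits are for, here is a minimal pure-Scala sketch for a single continuous feature. It is written under simplifying assumptions and is not copied from Spark. It sorts the values and picks at most maxBins - 1 thresholds: midpoints when there are few distinct values, roughly equal-frequency cut points otherwise. It then maps a raw value to its bin index, which is the kind of lookup the TreePoint binning performs. The names candidateThresholds and binIndex are hypothetical. Spark's findSplits additionally samples the input and handles categorical features and sample weights, so its thresholds can differ.

```scala
object SplitSketch {
  // Hypothetical helper: at most (maxBins - 1) sorted thresholds for one continuous feature.
  def candidateThresholds(values: Seq[Double], maxBins: Int): Array[Double] = {
    val sorted = values.sorted
    val distinct = sorted.distinct
    if (distinct.length <= maxBins) {
      // Few distinct values: cut halfway between each pair of neighbouring values.
      distinct.sliding(2).collect { case Seq(a, b) => (a + b) / 2.0 }.toArray
    } else {
      // Many distinct values: cut at roughly equal-frequency positions in the sorted data.
      val stride = sorted.length.toDouble / maxBins
      (1 until maxBins).map(i => sorted((i * stride).toInt)).distinct.toArray
    }
  }

  // Bin index = number of thresholds strictly below the value (binary search on sorted
  // thresholds), so a value equal to a threshold falls on that threshold's "left" side.
  def binIndex(value: Double, thresholds: Array[Double]): Int = {
    var lo = 0
    var hi = thresholds.length
    while (lo < hi) {
      val mid = (lo + hi) >>> 1
      if (thresholds(mid) < value) lo = mid + 1 else hi = mid
    }
    lo
  }
}
```

With three distinct values and maxBins = 4, the thresholds are the midpoints 1.5 and 2.5, and every value lands in one of three bins.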
Then the first tree is trained:
```scala
val bcSplits = sc.broadcast(splits)

// Bin feature values (TreePoint representation).
// Cache input RDD for speedup during multiple passes.
val treePoints = TreePoint.convertToTreeRDD(
  retaggedInput, splits, metadata)
  .persist(StorageLevel.MEMORY_AND_DISK)
  .setName("binned tree points")

val firstCounts = BaggedPoint
  .convertToBaggedRDD(treePoints, treeStrategy.subsamplingRate, numSubsamples = 1,
    treeStrategy.bootstrap, (tp: TreePoint) => tp.weight, seed = seed)
  .map { bagged =>
    require(bagged.subsampleCounts.length == 1)
    require(bagged.sampleWeight == bagged.datum.weight)
    bagged.subsampleCounts.head
  }.persist(StorageLevel.MEMORY_AND_DISK)
  .setName("firstCounts at iter=0")

val firstBagged = treePoints.zip(firstCounts)
  .map { case (treePoint, count) =>
    // according to current design, treePoint.weight == baggedPoint.sampleWeight
    new BaggedPoint[TreePoint](treePoint, Array(count), treePoint.weight)
  }

val firstTreeModel = RandomForest.runBagged(baggedInput = firstBagged,
  metadata = metadata, bcSplits = bcSplits, strategy = treeStrategy, numTrees = 1,
  featureSubsetStrategy = featureSubsetStrategy, seed = seed, instr = instr,
  parentUID = None)
  .head.asInstanceOf[DecisionTreeRegressionModel]
```
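With numSubsamples = 1, BaggedPoint.convertToBaggedRDD produces one count per sample: how many times that sample enters this tree's training set. The sketch below is a hedged plain-Scala rendition of my reading of Spark's three cases:

- When bootstrap is true (sampling with replacement), the counts are Poisson(subsamplingRate).
- When bootstrap is false and subsamplingRate < 1, the counts are 0/1 Bernoulli.
- Otherwise, every sample gets a count of 1.

The helper subsampleCounts is hypothetical. It uses scala.util.Random with Knuth's Poisson algorithm, so its draws will not match Spark's for the same seed.

```scala
import scala.util.Random

object BaggingSketch {
  // Hypothetical helper: one subsample count per sample, for a single subsample.
  def subsampleCounts(n: Int, subsamplingRate: Double, bootstrap: Boolean, seed: Long): Array[Int] = {
    val rng = new Random(seed)
    if (bootstrap) {
      // With replacement: each sample appears Poisson(subsamplingRate) times.
      Array.fill(n)(poisson(subsamplingRate, rng))
    } else if (subsamplingRate < 1.0) {
      // Without replacement: keep each sample with probability subsamplingRate.
      Array.fill(n)(if (rng.nextDouble() < subsamplingRate) 1 else 0)
    } else {
      // No sampling: every sample is used exactly once.
      Array.fill(n)(1)
    }
  }

  // Knuth's algorithm: count uniform draws until their running product drops below e^-lambda.
  private def poisson(lambda: Double, rng: Random): Int = {
    val limit = math.exp(-lambda)
    var k = 0
    var p = rng.nextDouble()
    while (p > limit) {
      k += 1
      p *= rng.nextDouble()
    }
    k
  }
}
```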
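Following the GBT procedure, the first tree is added to the ensemble with weight 1.0. Then the prediction and the loss error of every training sample are computed: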
```scala
val firstTreeWeight = 1.0
baseLearners(0) = firstTreeModel
baseLearnerWeights(0) = firstTreeWeight

var predError = computeInitialPredictionAndError(
  treePoints, firstTreeWeight, firstTreeModel, loss, bcSplits)
predErrorCheckpointer.update(predError)
logDebug(s"error of gbt = ${computeWeightedError(treePoints, predError)}")
```
```scala
/**
 * Compute the initial predictions and errors for a dataset for the first
 * iteration of gradient boosting.
 * @param data: training data.
 * @param initTreeWeight: learning rate assigned to the first tree.
 * @param initTree: first DecisionTreeModel.
 * @param loss: evaluation metric.
 * @return an RDD with each element being a zip of the prediction and error
 *         corresponding to every sample.
 */
def computeInitialPredictionAndError(
    data: RDD[TreePoint],
    initTreeWeight: Double,
    initTree: DecisionTreeRegressionModel,
    loss: OldLoss,
    bcSplits: Broadcast[Array[Array[Split]]]): RDD[(Double, Double)] = {
  data.map { treePoint =>
    val pred = updatePrediction(treePoint, 0.0, initTree, initTreeWeight, bcSplits.value)
    val error = loss.computeError(pred, treePoint.label)
    (pred, error)
  }
}
```
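In plain terms, computeInitialPredictionAndError starts every prediction at 0.0, adds initTreeWeight times the first tree's output, and evaluates the loss on the result. The sketch below reproduces that on an in-memory Seq under two assumptions:

- The loss is squared error; Spark's SquaredError computes (label - prediction)^2.
- computeWeightedError is the weight-averaged error, sum(weight * error) / sum(weight).

Sample, the function-valued tree, and all helper names are hypothetical stand-ins for TreePoint, DecisionTreeRegressionModel, and the RDD code.

```scala
object InitialErrorSketch {
  // Hypothetical stand-in for TreePoint: raw features instead of binned ones.
  final case class Sample(features: Array[Double], label: Double, weight: Double)

  // Squared error, matching the formula of Spark's SquaredError loss.
  def squaredError(prediction: Double, label: Double): Double = {
    val err = label - prediction
    err * err
  }

  // Start from 0.0, add the weighted output of the first tree,
  // and pair each prediction with its error.
  def initialPredictionAndError(
      data: Seq[Sample],
      initTreeWeight: Double,
      initTree: Array[Double] => Double): Seq[(Double, Double)] =
    data.map { s =>
      val pred = 0.0 + initTreeWeight * initTree(s.features)
      (pred, squaredError(pred, s.label))
    }

  // Weighted mean of the per-sample errors (assumed behaviour of computeWeightedError).
  def weightedError(data: Seq[Sample], predError: Seq[(Double, Double)]): Double = {
    val pairs = data.zip(predError).map { case (s, (_, err)) => (s.weight * err, s.weight) }
    pairs.map(_._1).sum / pairs.map(_._2).sum
  }
}
```

For example, a stump that predicts 1.0 for x <= 1.5 and 3.0 otherwise, evaluated on the samples (x = 1, y = 1, weight 1) and (x = 2, y = 4, weight 3), gives errors 0 and 1 and a weighted error of 3 / 4 = 0.75.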
More trees are then trained inside a while loop:
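In formula form, and leaving out subsampling and validation, boost builds the model as follows. Here h_m is the m-th tree, η is learningRate, L is the loss, and M is numIterations. Trees are numbered from 1 here, whereas the code's baseLearners array is 0-indexed. This is my own summary of the code, not a formula taken from Spark's documentation.

```latex
\begin{aligned}
F_1(x) &= 1.0 \cdot h_1(x), \quad h_1 \text{ fitted to } \{(x_i, y_i)\} \\
r_i^{(m)} &= -\left.\frac{\partial L(y_i, F)}{\partial F}\right|_{F = F_{m-1}(x_i)}, \quad m = 2, \dots, M \\
h_m &\text{ fitted to } \{(x_i, r_i^{(m)})\} \\
F_m(x) &= F_{m-1}(x) + \eta \, h_m(x)
\end{aligned}
```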
```scala
var bestM = 1
var m = 1
var doneLearning = false
while (m < numIterations && !doneLearning) {
  timer.start(s"building tree $m")
  logDebug("###################################################")
  logDebug("Gradient boosting tree iteration " + m)
  logDebug("###################################################")

  // (label: Double, count: Int)
  val labelWithCounts = BaggedPoint
    .convertToBaggedRDD(treePoints, treeStrategy.subsamplingRate, numSubsamples = 1,
      treeStrategy.bootstrap, (tp: TreePoint) => tp.weight, seed = seed + m)
    .zip(predError)
    .map { case (bagged, (pred, _)) =>
      require(bagged.subsampleCounts.length == 1)
      require(bagged.sampleWeight == bagged.datum.weight)
      // Update labels with pseudo-residuals
      val newLabel = -loss.gradient(pred, bagged.datum.label)
      (newLabel, bagged.subsampleCounts.head)
    }.persist(StorageLevel.MEMORY_AND_DISK)
    .setName(s"labelWithCounts at iter=$m")

  val bagged = treePoints.zip(labelWithCounts)
    .map { case (treePoint, (newLabel, count)) =>
      val newTreePoint = new TreePoint(newLabel, treePoint.binnedFeatures, treePoint.weight)
      // according to current design, treePoint.weight == baggedPoint.sampleWeight
      new BaggedPoint[TreePoint](newTreePoint, Array(count), treePoint.weight)
    }

  val model = RandomForest.runBagged(baggedInput = bagged,
    metadata = metadata, bcSplits = bcSplits, strategy = treeStrategy,
    numTrees = 1, featureSubsetStrategy = featureSubsetStrategy,
    seed = seed + m, instr = None, parentUID = None)
    .head.asInstanceOf[DecisionTreeRegressionModel]

  labelWithCounts.unpersist()
  timer.stop(s"building tree $m")

  // Update partial model
  baseLearners(m) = model
  // Note: The setting of baseLearnerWeights is incorrect for losses other than SquaredError.
  // Technically, the weight should be optimized for the particular loss.
  // However, the behavior should be reasonable, though not optimal.
  baseLearnerWeights(m) = learningRate

  predError = updatePredictionError(
    treePoints, predError, baseLearnerWeights(m),
    baseLearners(m), loss, bcSplits)
  predErrorCheckpointer.update(predError)
  logDebug(s"error of gbt = ${computeWeightedError(treePoints, predError)}")

  if (validate) {
    // Stop training early if
    // 1. Reduction in error is less than the validationTol or
    // 2. If the error increases, that is if the model is overfit.
    // We want the model returned corresponding to the best validation error.
    validatePredError = updatePredictionError(
      validationTreePoints, validatePredError, baseLearnerWeights(m),
      baseLearners(m), loss, bcSplits)
    validatePredErrorCheckpointer.update(validatePredError)
    val currentValidateError = computeWeightedError(validationTreePoints, validatePredError)
    if (bestValidateError - currentValidateError < validationTol * Math.max(
      currentValidateError, 0.01)) {
      doneLearning = true
    } else if (currentValidateError < bestValidateError) {
      bestValidateError = currentValidateError
      bestM = m + 1
    }
  }
  m += 1
}
```
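The early-stopping rule in the if (validate) block can be isolated as a pure function over the sequence of validation errors. In this sketch, errors(k) is the weighted validation error after k + 1 trees. I assume bestValidateError starts as the validation error of the first tree alone; that initialization is not part of the excerpt above. I read bestM as the number of trees in the model that is kept. bestNumTrees is a hypothetical name.

```scala
object EarlyStoppingSketch {
  // errors(k) = weighted validation error after k + 1 trees.
  // Assumption: bestValidateError starts as errors(0), the error of the first tree alone.
  // Returns bestM, the number of trees kept.
  def bestNumTrees(errors: Seq[Double], validationTol: Double): Int = {
    var bestValidateError = errors.head
    var bestM = 1
    var m = 1
    var doneLearning = false
    while (m < errors.length && !doneLearning) {
      val currentValidateError = errors(m)
      // Stop when the improvement is smaller than validationTol relative to the current
      // error (floored at 0.01); with validationTol > 0 this also fires if the error rises.
      if (bestValidateError - currentValidateError <
          validationTol * math.max(currentValidateError, 0.01)) {
        doneLearning = true
      } else if (currentValidateError < bestValidateError) {
        bestValidateError = currentValidateError
        bestM = m + 1
      }
      m += 1
    }
    bestM
  }
}
```

With errors 1.0, 0.5, 0.3, 0.299, ... and validationTol = 0.01, the drop from 0.3 to 0.299 is below 0.01 × 0.299. Training therefore stops at the fourth tree and the first three trees are kept, regardless of later errors.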
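Following J.H. Friedman's paper, each new tree is actually fitted to the negative gradient of the loss rather than to the residual itself. For squared error the two differ only by a constant factor: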
```scala
val newLabel = -loss.gradient(pred, bagged.datum.label)
```
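To see the negative-gradient labels at work end to end, here is a small, self-contained sketch of the boosting loop on one-dimensional data. As I understand them, the gradients follow the formulas used by Spark's SquaredError (-2 (y - F)) and AbsoluteError (-sign(y - F)). Everything else is a simplification: a least-squares regression stump (a depth-1 tree) stands in for RandomForest.runBagged with numTrees = 1, and there is no subsampling or validation. As in boost, the first tree is fitted to the original labels with weight 1.0, and later trees get weight learningRate. All names are hypothetical.

```scala
object BoostingSketch {
  // A loss exposes dL/dF, the gradient of L(label, F) w.r.t. the prediction F.
  trait Loss { def gradient(prediction: Double, label: Double): Double }

  // L = (y - F)^2  =>  dL/dF = -2 (y - F)
  object SquaredError extends Loss {
    def gradient(prediction: Double, label: Double): Double = -2.0 * (label - prediction)
  }

  // L = |y - F|  =>  dL/dF = -sign(y - F)
  object AbsoluteError extends Loss {
    def gradient(prediction: Double, label: Double): Double =
      if (label - prediction < 0) 1.0 else -1.0
  }

  // Hypothetical weak learner: a one-feature regression stump.
  final case class Stump(threshold: Double, left: Double, right: Double) {
    def predict(x: Double): Double = if (x <= threshold) left else right
  }

  // Try every distinct value as a threshold; keep the least-squares best stump.
  def fitStump(xs: Array[Double], ys: Array[Double]): Stump = {
    def mean(idx: Seq[Int]): Double =
      if (idx.isEmpty) 0.0 else idx.map(i => ys(i)).sum / idx.size
    val stumps = xs.distinct.sorted.map { t =>
      val (left, right) = xs.indices.partition(i => xs(i) <= t)
      Stump(t, mean(left), mean(right))
    }
    stumps.minBy(s => xs.indices.map(i => math.pow(ys(i) - s.predict(xs(i)), 2)).sum)
  }

  // Returns the trees and their weights, like (baseLearners, baseLearnerWeights).
  def boost(xs: Array[Double], ys: Array[Double], loss: Loss,
            numIterations: Int, learningRate: Double): (Seq[Stump], Seq[Double]) = {
    // First tree: fitted to the original labels, weight 1.0.
    val first = fitStump(xs, ys)
    val learners = scala.collection.mutable.ArrayBuffer(first)
    val weights = scala.collection.mutable.ArrayBuffer(1.0)
    val preds = xs.map(x => first.predict(x))
    for (_ <- 1 until numIterations) {
      // New labels: the negative gradient at the current predictions.
      val newLabels = xs.indices.map(i => -loss.gradient(preds(i), ys(i))).toArray
      val tree = fitStump(xs, newLabels)
      learners += tree
      weights += learningRate
      // Add the new tree's weighted output to every prediction.
      for (i <- xs.indices) preds(i) += learningRate * tree.predict(xs(i))
    }
    (learners.toSeq, weights.toSeq)
  }

  // Ensemble prediction: weighted sum of the trees' outputs.
  def predict(learners: Seq[Stump], weights: Seq[Double], x: Double): Double =
    learners.zip(weights).map { case (t, w) => w * t.predict(x) }.sum
}
```

With SquaredError and learningRate = 0.5, each step amounts to a least-squares stump fitted to the current residuals y - F. Fitting the labels 2(y - F) doubles the stump's leaf values without changing its threshold, and the 0.5 weight undoes the doubling. With AbsoluteError, only the sign of each residual reaches the next tree.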
That completes an outline of how GBT training works. Many details remain unexplored; I'll come back to them if the need arises O(∩_∩)O