The Spark tree model source code is fairly complex; here I'll try to sketch out its overall structure.
First, a quick review of tree models. For regression, the rule for splitting a node is to minimize the sum of squared errors (SSE) of the two resulting groups. For each feature, find the split point with the smallest SSE; then, across all features, pick the split point whose SSE is smallest overall. Spark's implementation does not sort the values of each feature; instead it uses histograms (binned candidate split points).
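The split rule above can be sketched as follows. This is a minimal illustration, not Spark's code: for one numeric feature, the candidate thresholds are fixed histogram bin boundaries rather than every sorted value, and we keep the threshold that minimizes the combined SSE of the two groups. The function names and the tiny dataset are made up for the example.

```python
def sse(ys):
    """Sum of squared errors of a group around its mean."""
    if not ys:
        return 0.0
    m = sum(ys) / len(ys)
    return sum((y - m) ** 2 for y in ys)

def best_split(xs, ys, num_bins=4):
    """Return (threshold, total_sse) of the best binary split on one feature.

    Candidate thresholds are histogram bin boundaries, mirroring the idea
    that Spark bins feature values instead of sorting them.
    """
    lo, hi = min(xs), max(xs)
    width = (hi - lo) / num_bins
    best_t, best_err = None, float("inf")
    for i in range(1, num_bins):
        t = lo + i * width
        left = [y for x, y in zip(xs, ys) if x <= t]
        right = [y for x, y in zip(xs, ys) if x > t]
        total = sse(left) + sse(right)
        if total < best_err:
            best_t, best_err = t, total
    return best_t, best_err

# Two clearly separated clusters of targets: the best threshold should
# fall between x=3 and x=10, giving a very small combined SSE.
xs = [1.0, 2.0, 3.0, 10.0, 11.0, 12.0]
ys = [1.0, 1.2, 0.9, 5.0, 5.1, 4.9]
t, err = best_split(xs, ys)
```

Repeating this per feature and taking the feature with the smallest SSE gives the node's split, exactly as described above.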
Since Spark's decision tree implementation is simply a random forest with a single tree, we can go straight to the random forest code.
The train method of RandomForestRegressor calls run in org.apache.spark.ml.tree.impl.RandomForest, and run calls runBagged; as the name suggests, this step involves bootstrap sampling (bagging).
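The bagging step can be sketched like this. This is an illustration rather than Spark's implementation (Spark's BaggedPoint keeps per-row sample counts and, as I understand it, approximates with-replacement sampling via Poisson draws); here we draw indices with replacement directly and record how often each row appears in each tree's sample. The function name and parameters are invented for the example.

```python
import random

def bootstrap_counts(n, num_trees, seed=42):
    """For each tree, count how many times each of the n rows is drawn
    in a bootstrap sample (n draws with replacement)."""
    rng = random.Random(seed)
    counts = []
    for _ in range(num_trees):
        c = [0] * n
        for _ in range(n):
            c[rng.randrange(n)] += 1
        counts.append(c)
    return counts

counts = bootstrap_counts(n=5, num_trees=3)
```

Each tree's counts sum to n, since every bootstrap sample consists of n draws; storing counts per row is more compact than materializing a resampled dataset per tree.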
runBagged iterates over a nodeStack, where each tree is identified by a treeIndex. topNodes is an array holding the root of each tree; at the end it is used to construct the random forest model.
/*
Stack of nodes to train: (treeIndex, node)
The reason this is a stack is that we train many trees at once, but we want to focus on
completing trees, rather than training all simultaneously. If we are splitting nodes from
1 tree, then the new nodes to split will be put at the top of this stack, so we will continue
training the same tree in the next iteration. This focus allows us to send fewer trees to
workers on each iteration; see topNodesForGroup below.
*/
val nodeStack = new mutable.ListBuffer[(Int, LearningNode)]
val topNodes = Array.fill[LearningNode](numTrees)(LearningNode.emptyNode(nodeIndex = 1))
for (treeIndex <- 0 until numTrees) {
  nodeStack.prepend((treeIndex, topNodes(treeIndex)))
}
while (nodeStack.nonEmpty) {
// Colle