spark源码分析之随机森林(Random Forest)(一)
spark源码分析之随机森林(Random Forest)(二)
spark源码分析之随机森林(Random Forest)(三)
spark源码分析之随机森林(Random Forest)(四)
7. 构造随机森林
在上面的训练过程可以看到,从根节点topNode中不断向下分裂一直到触发截止条件就构造了一棵树所有的node,因此构造整个森林也是非常简单
//构造
val trees = topNodes.map(topNode => new DecisionTreeModel(topNode, strategy.algo))
//返回rf模型
new RandomForestModel(strategy.algo, trees)
8. 随机森林模型
8.1. TreeEnsembleModel
随机森林RandomForestModel继承自树集合模型TreeEnsembleModel
class TreeEnsembleModel(
protected val algo: Algo,
protected val trees: Array[DecisionTreeModel],
protected val treeWeights: Array[Double],
protected val combiningStrategy: EnsembleCombiningStrategy)
- algo:Regression/Classification
- trees:树数组
- treeWeights:每棵树的权重,在RF中每棵树的权重是相同的,在Adaboost可能是不同的
- combiningStrategy:树合并时的策略,Sum/Average/Vote,分类的话应该是Vote,RF应该是Average,GBDT应该是Sum。
- sumWeights:成员变量,不在参数表中,是treeWeights的sum
预测函数
/**
* Predicts for a single data point using the weighted sum of ensemble predictions.
*
* @param features array representing a single data point
* @return predicted category from the trained model
*/
private def predictBySumming(features: Vector): Double &#