最后我们再来说说如何做预测,四中讲到run方法最终会返回RandomForestRegressionModel,而预测就是这个类的方法,实际中模型会调用transform,而transformImpl是transform里的具体实现方法,复写后如下
override protected def transformImpl(dataset: DataFrame): DataFrame = {//放入要预测的数据,返回加入预测列的数据
val bcastModel = dataset.sqlContext.sparkContext.broadcast(this)//广播当前模型
val predictUDF = udf { (features: Any) =>//定义了一个udf,用predict预测
bcastModel.value.predict(features.asInstanceOf[Vector])//通过asInstanceOf把col类转成Vector处理
}
dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))//withColumn方法很常用,第一个参数是column名,第二个参数是列值
}
override protected def predict(features: Vector): Double = {
// TODO: When we add a generic Bagging class, handle transform there. SPARK-7128
// Predict average of tree predictions.//其实就是取的平均值,从之前的讲解我们知道,就是节点的对于数据的加权平均,质心
// Ignore the weights since all are 1.0 for now.
_trees.map(_.rootNode.predictImpl(features).prediction).sum / numTrees//_trees是RandomForestRegressionModel组,rootNode是每个model的根节点,如下node节点有这个方法,返回叶节点,而prediction是叶节点的成员,所有树的预测结果求平均就是最终的预测值
}
node节点的predictImpl方法,这个方法就是给出一个特征向量,让这个向量从根节点移动到被划分的叶节点
override private[ml] def predictImpl(features: Vector): LeafNode = {
if (split.shouldGoLeft(features)) {根据节点里存的划分信息,判断向量走向,下面都是递归调用,最终返回叶节点
leftChild.predictImpl(features)
} else {
rightChild.predictImpl(features)
}
}
至此,从源码层面随机森林讲完了,有很多东西需要从不同维度再仔细思考,欢迎提出宝贵意见