spark ml 随机森林源码笔记五

最后我们再来说说如何做预测,四中讲到run方法最终会返回RandomForestRegressionModel,而预测就是这个类的方法,实际中模型会调用transform,而transformImpl是transform里的具体实现方法,复写后如下

  override protected def transformImpl(dataset: DataFrame): DataFrame = {//放入要预测的数据,返回加入预测列的数据
    val bcastModel = dataset.sqlContext.sparkContext.broadcast(this)//广播当前模型
    val predictUDF = udf { (features: Any) =>//定义了一个udf,用predict预测
      bcastModel.value.predict(features.asInstanceOf[Vector])//通过asInstanceOf把col类转成Vector处理
    }
    dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))//withColumn方法很常用,第一个参数是column名,第二个参数是列值
  }

  override protected def predict(features: Vector): Double = {
    // TODO: When we add a generic Bagging class, handle transform there.  SPARK-7128
    // Predict average of tree predictions.//其实就是取的平均值,从之前的讲解我们知道,就是节点的对于数据的加权平均,质心
    // Ignore the weights since all are 1.0 for now.
    _trees.map(_.rootNode.predictImpl(features).prediction).sum / numTrees//_trees是RandomForestRegressionModel组,rootNode是每个model的根节点,如下node节点有这个方法,返回叶节点,而prediction是叶节点的成员,所有树的预测结果求平均就是最终的预测值
  }

node节点的predictImpl方法,这个方法就是给出一个特征向量,让这个向量从根节点移动到被划分的叶节点

  override private[ml] def predictImpl(features: Vector): LeafNode = {
    if (split.shouldGoLeft(features)) {根据节点里存的划分信息,判断向量走向,下面都是递归调用,最终返回叶节点
      leftChild.predictImpl(features)
    } else {
      rightChild.predictImpl(features)
    }
  }

至此,从源码层面随机森林讲完了,有很多东西需要从不同维度再仔细思考,欢迎提出宝贵意见

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值