spark中调用xgboost实现2分类时需对rawPrediction修改

1、修改代码

/**
   *  XGBoost模型预测结果修正,适配Spark内置评估方法
   *
   * @param res XGBoost模型预测结果
   * @return 修正后的预测结果
   */

  private def amendXGBPred(res: DataFrame): DataFrame = {
    val columns = res.columns

    if (columns.contains("rawPrediction")) {
      val aRes  = res.withColumnRenamed("rawPrediction", "rawPrediction_Ori")
      val code = (arg: Vector) => {//这个函数使原来的vector,变成新的vector
        val rawPre = arg.apply(0)
        new DenseVector(Array(-1.0 * rawPre, rawPre))
      }
      val addCol = udf(code)
      val columns = aRes.columns
      aRes.selectExpr(columns:_*).withColumn("rawPrediction", addCol(aRes("rawPrediction_Ori")))
    } else {
      res
    }

2、原由

一般的2分类模型,rawPrediction有2列,一列是分类为0的原始预测数值、一列是1的;

但,XGBoost 0.81 Java 开源版本有bug,二分类预测结果rawPrediction只有一列数据,是分类为1的预测数值;

而,spark内置的交叉验证源代码评估时使用的是rawPrediction列,因此对XGBoost算法产生的rawPrediction列进行下调整修改;

3、图片

 

 

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值