1、修改代码
/**
* XGBoost模型预测结果修正,适配Spark内置评估方法
*
* @param res XGBoost模型预测结果
* @return 修正后的预测结果
*/
private def amendXGBPred(res: DataFrame): DataFrame = {
val columns = res.columns
if (columns.contains("rawPrediction")) {
val aRes = res.withColumnRenamed("rawPrediction", "rawPrediction_Ori")
val code = (arg: Vector) => {//这个函数使原来的vector,变成新的vector
val rawPre = arg.apply(0)
new DenseVector(Array(-1.0 * rawPre, rawPre))
}
val addCol = udf(code)
val columns = aRes.columns
aRes.selectExpr(columns:_*).withColumn("rawPrediction", addCol(aRes("rawPrediction_Ori")))
} else {
res
}
2、原由
一般的2分类模型,rawPrediction有2列,一列是分类为0的原始预测数值、一列是1的;
但,XGBoost 0.81 Java 开源版本有bug,二分类预测结果rawPrediction只有一列数据,是分类为1的预测数值;
而,spark内置的交叉验证源代码评估时使用的是rawPrediction列,因此对XGBoost算法产生的rawPrediction列进行下调整修改;
3、图片