Alink漫谈(八) : 二分类评估 AUC、K-S、PRC、Precision、Recall、LiftChart 如何实现

目录
Alink漫谈(八) : 二分类评估 AUC、K-S、PRC、Precision、Recall、LiftChart 如何实现
0x00 摘要
0x01 相关概念
0x02 示例代码
2.1 主要思路
0x03 批处理
3.1 EvalBinaryClassBatchOp
3.2 BaseEvalClassBatchOp
3.2.0 调用关系综述
3.2.1 calLabelPredDetailLocal
3.2.1.1 flatMap
3.2.1.2 reduceGroup
3.2.1.3 mapPartition
3.2.2 ReduceBaseMetrics
3.2.3 SaveDataAsParams
3.2.4 计算混淆矩阵
3.2.4.1 原始矩阵
3.2.4.2 计算标签
3.2.4.3 具体代码
0x04 流处理
4.1 示例
4.1.1 主类
4.1.2 TimeMemSourceStreamOp
4.1.3 Source
4.2 BaseEvalClassStreamOp
4.2.1 PredDetailLabel
4.2.2 AllDataMerge
4.2.3 SaveDataStream
4.2.4 Union
4.2.4.1 allOutput
4.2.4.2 windowOutput
0xFF 参考
0x00 摘要
Alink 是阿里巴巴基于实时计算引擎 Flink 研发的新一代机器学习算法平台,是业界首个同时支持批式算法、流式算法的机器学习平台。二分类评估是对二分类算法的预测结果进行效果评估。本文将剖析Alink中对应代码实现。

0x01 相关概念
如果对本文某些概念有疑惑,可以参见之前文章 [白话解析] 通过实例来梳理概念 :准确率 (Accuracy)、精准率(Precision)、召回率(Recall) 和 F值(F-Measure)

0x02 示例代码
public class EvalBinaryClassExample {

AlgoOperator getData(boolean isBatch) {
    Row[] rows = new Row[]{
            Row.of("prefix1", "{\"prefix1\": 0.9, \"prefix0\": 0.1}"),
            Row.of("prefix1", "{\"prefix1\": 0.8, \"prefix0\": 0.2}"),
            Row.of("prefix1", "{\"prefix1\": 0.7, \"prefix0\": 0.3}"),
            Row.of("prefix0", "{\"prefix1\": 0.75, \"prefix0\": 0.25}"),
            Row.of("prefix0", "{\"prefix1\": 0.6, \"prefix0\": 0.4}")
    };

    String[] schema = new String[]{"label", "detailInput"};

    if (isBatch) {
        return new MemSourceBatchOp(rows, schema);
    } else {
        return new MemSourceStreamOp(rows, schema);
    }
}

public static void main(String[] args) throws Exception {
    EvalBinaryClassExample test = new EvalBinaryClassExample();
    BatchOperator batchData = (BatchOperator) test.getData(true);

    BinaryClassMetrics metrics = new EvalBinaryClassBatchOp()
            .setLabelCol("label")
            .setPredictionDetailCol("detailInput")
            .linkFrom(batchData)
            .collectMetrics();

    System.out.println("RocCurve:" + metrics.getRocCurve());
    System.out.println("AUC:" + metrics.getAuc());
    System.out.println("KS:" + metrics.getKs());
    System.out.println("PRC:" + metrics.getPrc());
    System.out.println("Accuracy:" + metrics.getAccuracy());
    System.out.println("Macro Precision:" + metrics.getMacroPrecision());
    System.out.println("Micro Recall:" + metrics.getMicroRecall());
    System.out.println("Weighted Sensitivity:" + metrics.getWeightedSensitivity());
}

}
程序输出

RocCurve:([0.0, 0.0, 0.0, 0.5, 0.5, 1.0, 1.0],[0.0, 0.3333333333333333, 0.6666666666666666, 0.6666666666666666, 1.0, 1.0, 1.0])
AUC:0.8333333333333333
KS:0.6666666666666666
PRC:0.9027777777777777
Accuracy:0.6
Macro Precision:0.3
Micro Recall:0.6
Weighted Sensitivity:0.6
在 Alink 中,二分类评估有批处理,流处理两种实现,下面一一为大家介绍( Alink 复杂之一在于大量精细的数据结构,所以下文会大量打印程序中变量以便大家理解)。

2.1 主要思路
把 [0,1] 分成假设 100000个桶(bin)。所以得到positiveBin / negativeBin 两个100000的数组。

根据输入给positiveBin / negativeBin赋值。positiveBin就是 TP + FP,negativeBin就是 TN + FN。这些是后续计算的基础。

遍历bins中每一个有意义的点,计算出totalTrue和totalFalse,并且在每一个点上计算该点的混淆矩阵,tpr,以及rocCurve,recallPrecisionCurve,liftChart在该点对应的数据;

依据曲线内容计算并且存储 AUC/PRC/KS

具体后续还有详细调用关系综述。

0x03 批处理
3.1 EvalBinaryClassBatchOp
EvalBinaryClassBatchOp是二分类评估的实现,功能是计算二分类的评估指标(evaluation metrics)。

输入有两种:

label column and predResult column
label column and predDetail column。如果有predDetail,则predResult被忽略
我们例子中 “prefix1” 就是 label,"{“prefix1”: 0.9, “prefix0”: 0.1}" 就是 predDetail

Row.of(“prefix1”, “{“prefix1”: 0.9, “prefix0”: 0.1}”)
具体类摘录如下:

public class EvalBinaryClassBatchOp extends BaseEvalClassBatchOp implements BinaryEvaluationParams , EvaluationMetricsCollector {

@Override
public BinaryClassMetrics collectMetrics() {
	return new BinaryClassMetrics(this.collect().get(0));
}  

}
可以看到,其主要工作都是在基类BaseEvalClassBatchOp中完成,所以我们会首先看BaseEvalClassBatchOp。

3.2 BaseEvalClassBatchOp
我们还是从 linkFrom 函数入手,其主要是做了几件事:

获取配置信息
从输入中提取某些列:“label”,“detailInput”
calLabelPredDetailLocal会按照partition分别计算evaluation metrics
综合reduce上述计算结果
SaveDataAsParams函数会把最终数值输入到 output table
具体代码如下

@Override
public T linkFrom(BatchOperator<?>… inputs) {
BatchOperator<?> in = checkAndGetFirst(inputs);
String labelColName = this.get(MultiEvaluationParams.LABEL_COL);
String positiveValue = this.get(BinaryEvaluationParams.POS_LABEL_VAL_STR);

// Judge the evaluation type from params.
ClassificationEvaluationUtil.Type type = ClassificationEvaluationUtil.judgeEvaluationType(this.getParams());

DataSet<BaseMetricsSummary> res;
switch (type) {
    case PRED_DETAIL: {
        String predDetailColName = this.get(MultiEvaluationParams.PREDICTION_DETAIL_COL);
        // 从输入中提取某些列:"label","detailInput" 
        DataSet<Row> data = in.select(new String[] {labelColName, predDetailColName}).getDataSet();
        // 按照partition分别计算evaluation metrics
        res = calLabelPredDetailLocal(data, positiveValue, binary);
        break;
    }
    ......
}

// 综合reduce上述计算结果
DataSet<BaseMetricsSummary> metrics = res
    .reduce(new EvaluationUtil.ReduceBaseMetrics());

// 把最终数值输入到 output table
this.setOutput(metrics.flatMap(new EvaluationUtil.SaveDataAsParams()),
    new String[] {DATA_OUTPUT}, new TypeInformation[] {Types.STRING});

return (T)this;

}

// 执行中一些变量如下
labelColName = “label”
predDetailColName = “detailInput”
type = {ClassificationEvaluationUtil$Type@2532} “PRED_DETAIL”
binary = true
positiveValue = null
3.2.0 调用关系综述
因为后续代码调用关系复杂,所以先给出一个调用关系:

从输入中提取某些列:“label”,“detailInput”,in.select(new String[] {labelColName, predDetailColName}).getDataSet()。因为可能输入还有其他列,而只有某些列是我们计算需要的,所以只提取这些列。
按照partition分别计算evaluation metrics,即调用 calLabelPredDetailLocal(data, positiveValue, binary);
flatMap会从label列和prediction列中,取出所有labels(注意是取出labels的名字 ),发送给下游算子。
reduceGroup主要功能是通过 buildLabelIndexLabelArray 去重 “labels名字”,然后给每一个label一个ID,得到一个 <labels, ID>的map,最后返回是二元组(map, labels),即({prefix1=0, prefix0=1},[prefix1, prefix0])。从后文看,<labels, ID>Map看来是多分类才用到。二分类只用到了labels。
mapPartition 分区调用 CalLabelDetailLocal 来计算混淆矩阵,主要是分区调用getDetailStatistics,前文中得到的二元组(map, labels)会作为参数传递进来 。
getDetailStatistics 遍历 rows 数据,提取每一个item(比如 “prefix1,{“prefix1”: 0.8, “prefix0”: 0.2}”),然后通过updateBinaryMetricsSummary累积计算混淆矩阵所需数据。
updateBinaryMetricsSummary 把 [0,1] 分成假设 100000个桶(bin)。所以得到positiveBin / negativeBin 两个100000的数组。positiveBin就是 TP + FP,negativeBin就是 TN + FN。
如果某个 sample 为 正例 (positive value) 的概率是 p, 则该 sample 对应的 bin index 就是 p * 100000。如果 p 被预测为正例 (positive value) ,则positiveBin[index]++,
否则就是被预测为负例(negative value) ,则negativeBin[index]++。
综合reduce上述计算结果,metrics = res.reduce(new EvaluationUtil.ReduceBaseMetrics());
具体计算是在BinaryMetricsSummary.merge,其作用就是Merge the bins, and add the logLoss。
把最终数值输入到 output table,setOutput(metrics.flatMap(new EvaluationUtil.SaveDataAsParams()…);
归并所有BaseMetrics后,得到total BaseMetrics,计算indexes存入params。collector.collect(t.toMetrics().serialize());
实际业务在BinaryMetricsSummary.toMetrics,即基于bin的信息计算,然后存储到params。
extractMatrixThreCurve函数取出非空的bins,据此计算出ConfusionMatrix array(混淆矩阵), threshold array, rocCurve/recallPrecisionCurve/LiftChart.
遍历bins中每一个有意义的点,计算出totalTrue和totalFalse,并且在每一个点上计算:
curTrue += positiveBin[index]; curFalse += negativeBin[index];
得到该点的混淆矩阵 new ConfusionMatrix(new long[][] { {curTrue, curFalse}, {totalTrue - curTrue, totalFalse - curFalse}});
得到 tpr = (totalTrue == 0 ? 1.0 : 1.0 * curTrue / totalTrue);
rocCurve,recallPrecisionCurve,liftChart在该点对应的数据;
依据曲线内容计算并且存储 AUC/PRC/KS
对生成的rocCurve/recallPrecisionCurve/LiftChart输出进行抽样
依据抽样后的输出存储 RocCurve/RecallPrecisionCurve/LiftChar
存储正例样本的度量指标
存储Logloss
Pick the middle point where threshold is 0.5.
3.2.1 calLabelPredDetailLocal
本函数按照partition分别计算评估指标 evaluation metrics。是的,这代码很短,但是有个地方需要注意。有时候越简单的地方越容易疏漏。容易疏漏点是:

第一行代码的结果 labels 是第二行代码的参数,而并非第二行主体。第二行代码主体和第一行代码主体一样,都是data。

private static DataSet calLabelPredDetailLocal(DataSet data, final String positiveValue, oolean binary) {

DataSet<Tuple2<Map<String, Integer>, String[]>> labels = data.flatMap(new FlatMapFunction<Row, String>() {
    @Override
    public void flatMap(Row row, Collector<String> collector) {
        TreeMap<String, Double> labelProbMap;
        if (EvaluationUtil.checkRowFieldNotNull(row)) {
            labelProbMap = EvaluationUtil.extractLabelProbMap(row);
            labelProbMap.keySet().forEach(collector::collect);
            collector.collect(row.getField(0).toString());
        }
    }
})
  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值