0x00 摘要
Alink 是阿里巴巴基于实时计算引擎 Flink 研发的新一代机器学习算法平台,是业界首个同时支持批式算法、流式算法的机器学习平台。二分类评估是对二分类算法的预测结果进行效果评估。本文将剖析Alink中对应代码实现
public class EvalBinaryClassExample {
AlgoOperator getData(boolean isBatch) {
Row[] rows = new Row[]{
Row.of("prefix1", "{\"prefix1\": 0.9, \"prefix0\": 0.1}"),
Row.of("prefix1", "{\"prefix1\": 0.8, \"prefix0\": 0.2}"),
Row.of("prefix1", "{\"prefix1\": 0.7, \"prefix0\": 0.3}"),
Row.of("prefix0", "{\"prefix1\": 0.75, \"prefix0\": 0.25}"),
Row.of("prefix0", "{\"prefix1\": 0.6, \"prefix0\": 0.4}")
};
String[] schema = new String[]{"label", "detailInput"};
if (isBatch) {
return new MemSourceBatchOp(rows, schema);
} else {
return new MemSourceStreamOp(rows, schema);
}
}
public static void main(String[] args) throws Exception {
EvalBinaryClassExample test = new EvalBinaryClassExample();
BatchOperator batchData = (BatchOperator) test.getData(true);
BinaryClassMetrics metrics = new EvalBinaryClassBatchOp()
.setLabelCol("label")
.setPredictionDetailCol("detailInput")
.linkFrom(batchData)
.collectMetrics();
System.out.println("RocCurve:" + metrics.getRocCurve());
System.out.println("AUC:" + metrics.getAuc());
System.out.println("KS:" + metrics.getKs());
System.out.println("PRC:" + metrics.getPrc());
System.out.println("Accuracy:" + metrics.getAccuracy());
System.out.println("Macro Precision:" + metrics.getMacroPrecision());
System.out.println("Micro Recall:" + metrics.getMicroRecall());
System.out.println("Weighted Sensitivity:" + metrics.getWeightedSensitivity());
}
}
程序输出
RocCurve:([0.0, 0.0, 0.0, 0.5, 0.5, 1.0, 1.0],[0.0, 0.3333333333333333, 0.6666666666666666, 0.6666666666666666, 1.0, 1.0, 1.0])
AUC:0.8333333333333333
KS:0.6666666666666666
PRC:0.9027777777777777
Accuracy:0.6
Macro Precision:0.3
Micro Recall:0.6
Weighted Sensitivity:0.6
在 Alink 中,二分类评估有批处理,流处理两种实现,下面一一为大家介绍( Alink 复杂之一在于大量精细的数据结构,所以下文会大量打印程序中变量以便大家理解)。
2.1 主要思路
-
把 [0,1] 分成假设 100000个桶(bin)。所以得到positiveBin / negativeBin 两个100000的数组。
-
根据输入给positiveBin / negativeBin赋值。positiveBin就是 TP + FP,negativeBin就是 TN + FN。这些是后续计算的基础。
-
遍历bins中每一个有意义的点,计算出totalTrue和totalFalse,并且在每一个点上计算该点的混淆矩阵,tpr,以及rocCurve,recallPrecisionCurve,liftChart在该点对应的数据;
-
依据曲线内容计算并且存储 AUC/PRC/KS
具体后续还有详细调用关系综述。
0x03 批处理
3.1 EvalBinaryClassBatchOp
EvalBinaryClassBatchOp是二分类评估的实现,功能是计算二分类的评估指标(evaluation metrics)。
输入有两种:
- label column and predResult column
- label column and predDetail column。如果有predDetail,则predResult被忽略
我们例子中 "prefix1"
就是 label,"{\"prefix1\": 0.9, \"prefix0\": 0.1}"
就是 predDetail
Row.of("prefix1", "{\"prefix1\": 0.9, \"prefix0\": 0.1}")
具体类摘录如下:
public class EvalBinaryClassBatchOp extends BaseEvalClassBatchOp<EvalBinaryClassBatchOp> implements BinaryEvaluationParams <EvalBinaryClassBatchOp>, EvaluationMetricsCollector<BinaryClassMetrics> {
@Override
public BinaryClassMetrics collectMetrics() {
return new BinaryClassMetrics(this.collect().get(0));
}
}
可以看到,其主要工作都是在基类BaseEvalClassBatchOp中完成,所以我们会首先看BaseEvalClassBatchOp。
3.2 BaseEvalClassBatchOp
我们还是从 linkFrom 函数入手,其主要是做了几件事:
- 获取配置信息
- 从输入中提取某些列:"label","detailInput"
- calLabelPredDetailLocal会按照partition分别计算evaluation metrics
- 综合reduce上述计算结果
- SaveDataAsParams函数会把最终数值输入到 output table
具体代码如下
@Override
public T linkFrom(BatchOperator<?>... inputs) {
BatchOperator<?> in = checkAndGetFirst(inputs);
String labelColName = this.get(MultiEvaluationParams.LABEL_COL);
String positiveValue = this.get(BinaryEvaluationParams.POS_LABEL_VAL_STR);
// Judge the evaluation type from params.
ClassificationEvaluationUtil.Type type = ClassificationEvaluationUtil.judgeEvaluationType(this.getParams());
DataSet<BaseMetricsSummary> res;
switch (type) {
case PRED_DETAIL: {
String predDetailColName = this.get(MultiEvaluationParams.PREDICTION_DETAIL_COL);
// 从输入中提取某些列:"label","detailInput"
DataSet<Row> data = in.select(new String[] {labelColName, predDetailColName}).getDataSet();
// 按照partition分别计算evaluation metrics
res = calLabelPredDetailLocal(data, positiveValue, binary);
break;
}
......
}
// 综合reduce上述计算结果
DataSet<BaseMetricsSummary> metrics = res
.reduce(new EvaluationUtil.ReduceBaseMetrics());
// 把最终数值输入到 output table
this.setOutput(metrics.flatMap(new EvaluationUtil.SaveDataAsParams()),
new String[] {DATA_OUTPUT}, new TypeInformation[] {Types.STRING});
return (T)this;
}
// 执行中一些变量如下
labelColName = "label"
predDetailColName = "detailInput"
type = {ClassificationEvaluationUtil$Type@2532} "PRED_DETAIL"
binary = true
positiveValue = null
3.2.0 调用关系综述
因为后续代码调用关系复杂,所以先给出一个调用关系:
- 从输入中提取某些列:"label","detailInput",in.select(new String[] {labelColName, predDetailColName}).getDataSet()。因为可能输入还有其他列,而只有某些列是我们计算需要的,所以只提取这些列。
- 按照partition分别计算evaluation metrics,即调用 calLabelPredDetailLocal(data, positiveValue, binary);
- flatMap会从label列和prediction列中,取出所有labels(注意是取出labels的名字 ),发送给下游算子。
- reduceGroup主要功能是通过 buildLabelIndexLabelArray 去重 "labels名字",然后给每一个label一个ID,得到一个 <labels, ID>的map,最后返回是二元组(map, labels),即({prefix1=0, prefix0=1},[prefix1, prefix0])。从后文看,<labels, ID>Map看来是多分类才用到。二分类只用到了labels。
- mapPartition 分区调用 CalLabelDetailLocal 来计算混淆矩阵,主要是分区调用getDetailStatistics,前文中得到的二元组(map, labels)会作为参数传递进来 。
- getDetailStatistics 遍历 rows 数据,提取每一个item(比如 "prefix1,{"prefix1": 0.8, "prefix0": 0.2}"),然后通过updateBinaryMetricsSummary累积计算混淆矩阵所需数据。
- updateBinaryMetricsSummary 把 [0,1] 分成假设 100000个桶(bin)。所以得到positiveBin / negativeBin 两个100000的数组。positiveBin就是 TP + FP,negativeBin就是 TN + FN。
- 如果某个 sample 为 正例 (positive value) 的概率是 p, 则该 sample 对应的 bin index 就是 p * 100000。如果 p 被预测为正例 (positive value) ,则positiveBin[index]++,
- 否则就是被预测为负例(negative value) ,则negativeBin[index]++。
- updateBinaryMetricsSummary 把 [0,1] 分成假设 100000个桶(bin)。所以得到positiveBin / negativeBin 两个100000的数组。positiveBin就是 TP + FP,negativeBin就是 TN + FN。
- getDetailStatistics 遍历 rows 数据,提取每一个item(比如 "prefix1,{"prefix1": 0.8, "prefix0": 0.2}"),然后通过updateBinaryMetricsSummary累积计算混淆矩阵所需数据。
- 综合reduce上述计算结果,metrics = res.reduce(new EvaluationUtil.ReduceBaseMetrics());
- 具体计算是在BinaryMetricsSummary.merge,其作用就是Merge the bins, and add the logLoss。
- 把最终数值输入到 output table,setOutput(metrics.flatMap(new EvaluationUtil.SaveDataAsParams()..);
- 归并所有BaseMetrics后,得到total BaseMetrics,计算indexes存入params。collector.collect(t.toMetrics().serialize());
- 实际业务在BinaryMetricsSummary.toMetrics,即基于bin的信息计算,然后存储到params。
- extractMatrixThreCurve函数取出非空的bins,据此计算出ConfusionMatrix array(混淆矩阵), threshold array, rocCurve/recallPrecisionCurve/LiftChart.
- 遍历bins中每一个有意义的点,计算出totalTrue和totalFalse,并且在每一个点上计算:
- curTrue += positiveBin[index]; curFalse += negativeBin[index];
- 得到该点的混淆矩阵 new ConfusionMatrix(new long[][] { {curTrue, curFalse}, {totalTrue - curTrue, totalFalse - curFalse}});
- 得到 tpr = (totalTrue == 0 ? 1.0 : 1.0 * curTrue / totalTrue);
- rocCurve,recallPrecisionCurve,liftChart在该点对应的数据;
- 依据曲线内容计算并且存储 AUC/PRC/KS
- 对生成的rocCurve/recallPrecisionCurve/LiftChart输出进行抽样
- 依据抽样后的输出存储 RocCurve/RecallPrecisionCurve/LiftChar
- 存储正例样本的度量指标
- 存储Logloss
- Pick the middle point where threshold is 0.5.
- extractMatrixThreCurve函数取出非空的bins,据此计算出ConfusionMatrix array(混淆矩阵), threshold array, rocCurve/recallPrecisionCurve/LiftChart.
- 实际业务在BinaryMetricsSummary.toMetrics,即基于bin的信息计算,然后存储到params。
- 归并所有BaseMetrics后,得到total BaseMetrics,计算indexes存入params。collector.collect(t.toMetrics().serialize());
3.2.1 calLabelPredDetailLocal
本函数按照partition分别计算评估指标 evaluation metrics。是的,这代码很短,但是有个地方需要注意。有时候越简单的地方越容易疏漏。容易疏漏点是:
第一行代码的结果 labels 是第二行代码的参数,而并非第二行主体。第二行代码主体和第一行代码主体一样,都是data。
private static DataSet<BaseMetricsSummary> calLabelPredDetailLocal(DataSet<Row> data, final String positiveValue, oolean binary) {
DataSet<Tuple2<Map<String, Integer>, String[]>> labels = data.flatMap(new FlatMapFunction<Row, String>() {
@Override
public void flatMap(Row row, Collector<String> collector) {
TreeMap<String, Double> labelProbMap;
if (EvaluationUtil.checkRowFieldNotNull(row)) {
labelProbMap = EvaluationUtil.extractLabelProbMap(row);
labelProbMap.keySet().forEach(collector::collect);
collector.collect(row.getField(0).toString());
}
}
}).reduceGroup(new EvaluationUtil.DistinctLabelIndexMap(binary, positiveValue));
return data
.rebalance()
.mapPartition(new CalLabelDetailLocal(binary))
.withBroadcastSet(labels, LABELS);
}
calLabelPredDetailLocal中具体分为三步骤:
- 在