Alink 是阿里巴巴基于实时计算引擎 Flink 研发的新

Alink 是阿里巴巴基于 Flink 的机器学习平台,支持批处理和流处理的二分类评估。本文深入解析二分类评估的批处理组件 EvalBinaryClassBatchOp 和 BaseEvalClassBatchOp,以及流处理组件的工作原理,涉及混淆矩阵、ROC 曲线等关键指标的计算过程。
摘要由CSDN通过智能技术生成

0x00 摘要

Alink 是阿里巴巴基于实时计算引擎 Flink 研发的新一代机器学习算法平台,是业界首个同时支持批式算法、流式算法的机器学习平台。二分类评估是对二分类算法的预测结果进行效果评估。本文将剖析Alink中对应代码实现

public class EvalBinaryClassExample {

    AlgoOperator getData(boolean isBatch) {
        Row[] rows = new Row[]{
                Row.of("prefix1", "{\"prefix1\": 0.9, \"prefix0\": 0.1}"),
                Row.of("prefix1", "{\"prefix1\": 0.8, \"prefix0\": 0.2}"),
                Row.of("prefix1", "{\"prefix1\": 0.7, \"prefix0\": 0.3}"),
                Row.of("prefix0", "{\"prefix1\": 0.75, \"prefix0\": 0.25}"),
                Row.of("prefix0", "{\"prefix1\": 0.6, \"prefix0\": 0.4}")
        };

        String[] schema = new String[]{"label", "detailInput"};

        if (isBatch) {
            return new MemSourceBatchOp(rows, schema);
        } else {
            return new MemSourceStreamOp(rows, schema);
        }
    }

    public static void main(String[] args) throws Exception {
        EvalBinaryClassExample test = new EvalBinaryClassExample();
        BatchOperator batchData = (BatchOperator) test.getData(true);

        BinaryClassMetrics metrics = new EvalBinaryClassBatchOp()
                .setLabelCol("label")
                .setPredictionDetailCol("detailInput")
                .linkFrom(batchData)
                .collectMetrics();

        System.out.println("RocCurve:" + metrics.getRocCurve());
        System.out.println("AUC:" + metrics.getAuc());
        System.out.println("KS:" + metrics.getKs());
        System.out.println("PRC:" + metrics.getPrc());
        System.out.println("Accuracy:" + metrics.getAccuracy());
        System.out.println("Macro Precision:" + metrics.getMacroPrecision());
        System.out.println("Micro Recall:" + metrics.getMicroRecall());
        System.out.println("Weighted Sensitivity:" + metrics.getWeightedSensitivity());
    }
}

程序输出

RocCurve:([0.0, 0.0, 0.0, 0.5, 0.5, 1.0, 1.0],[0.0, 0.3333333333333333, 0.6666666666666666, 0.6666666666666666, 1.0, 1.0, 1.0])
AUC:0.8333333333333333
KS:0.6666666666666666
PRC:0.9027777777777777
Accuracy:0.6
Macro Precision:0.3
Micro Recall:0.6
Weighted Sensitivity:0.6

在 Alink 中,二分类评估有批处理,流处理两种实现,下面一一为大家介绍( Alink 复杂之一在于大量精细的数据结构,所以下文会大量打印程序中变量以便大家理解)。

2.1 主要思路

  • 把 [0,1] 分成假设 100000个桶(bin)。所以得到positiveBin / negativeBin 两个100000的数组。

  • 根据输入给positiveBin / negativeBin赋值。positiveBin就是 TP + FP,negativeBin就是 TN + FN。这些是后续计算的基础。

  • 遍历bins中每一个有意义的点,计算出totalTrue和totalFalse,并且在每一个点上计算该点的混淆矩阵,tpr,以及rocCurve,recallPrecisionCurve,liftChart在该点对应的数据;

  • 依据曲线内容计算并且存储 AUC/PRC/KS

具体后续还有详细调用关系综述。

0x03 批处理

3.1 EvalBinaryClassBatchOp

EvalBinaryClassBatchOp是二分类评估的实现,功能是计算二分类的评估指标(evaluation metrics)。

输入有两种:

  • label column and predResult column
  • label column and predDetail column。如果有predDetail,则predResult被忽略

我们例子中 "prefix1" 就是 label,"{\"prefix1\": 0.9, \"prefix0\": 0.1}" 就是 predDetail

Row.of("prefix1", "{\"prefix1\": 0.9, \"prefix0\": 0.1}")

具体类摘录如下:

public class EvalBinaryClassBatchOp extends BaseEvalClassBatchOp<EvalBinaryClassBatchOp> implements BinaryEvaluationParams <EvalBinaryClassBatchOp>, EvaluationMetricsCollector<BinaryClassMetrics> {
  
	@Override
	public BinaryClassMetrics collectMetrics() {
		return new BinaryClassMetrics(this.collect().get(0));
	}  
}

可以看到,其主要工作都是在基类BaseEvalClassBatchOp中完成,所以我们会首先看BaseEvalClassBatchOp。

3.2 BaseEvalClassBatchOp

我们还是从 linkFrom 函数入手,其主要是做了几件事:

  • 获取配置信息
  • 从输入中提取某些列:"label","detailInput"
  • calLabelPredDetailLocal会按照partition分别计算evaluation metrics
  • 综合reduce上述计算结果
  • SaveDataAsParams函数会把最终数值输入到 output table

具体代码如下

@Override
public T linkFrom(BatchOperator<?>... inputs) {
    BatchOperator<?> in = checkAndGetFirst(inputs);
    String labelColName = this.get(MultiEvaluationParams.LABEL_COL);
    String positiveValue = this.get(BinaryEvaluationParams.POS_LABEL_VAL_STR);

    // Judge the evaluation type from params.
    ClassificationEvaluationUtil.Type type = ClassificationEvaluationUtil.judgeEvaluationType(this.getParams());

    DataSet<BaseMetricsSummary> res;
    switch (type) {
        case PRED_DETAIL: {
            String predDetailColName = this.get(MultiEvaluationParams.PREDICTION_DETAIL_COL);
            // 从输入中提取某些列:"label","detailInput" 
            DataSet<Row> data = in.select(new String[] {labelColName, predDetailColName}).getDataSet();
            // 按照partition分别计算evaluation metrics
            res = calLabelPredDetailLocal(data, positiveValue, binary);
            break;
        }
        ......
    }

    // 综合reduce上述计算结果
    DataSet<BaseMetricsSummary> metrics = res
        .reduce(new EvaluationUtil.ReduceBaseMetrics());

    // 把最终数值输入到 output table
    this.setOutput(metrics.flatMap(new EvaluationUtil.SaveDataAsParams()),
        new String[] {DATA_OUTPUT}, new TypeInformation[] {Types.STRING});

    return (T)this;
}

// 执行中一些变量如下
labelColName = "label"
predDetailColName = "detailInput"  
type = {ClassificationEvaluationUtil$Type@2532} "PRED_DETAIL"
binary = true
positiveValue = null  

3.2.0 调用关系综述

因为后续代码调用关系复杂,所以先给出一个调用关系

  • 从输入中提取某些列:"label","detailInput",in.select(new String[] {labelColName, predDetailColName}).getDataSet()。因为可能输入还有其他列,而只有某些列是我们计算需要的,所以只提取这些列。
  • 按照partition分别计算evaluation metrics,即调用 calLabelPredDetailLocal(data, positiveValue, binary);
    • flatMap会从label列和prediction列中,取出所有labels(注意是取出labels的名字 ),发送给下游算子。
    • reduceGroup主要功能是通过 buildLabelIndexLabelArray 去重 "labels名字",然后给每一个label一个ID,得到一个 <labels, ID>的map,最后返回是二元组(map, labels),即({prefix1=0, prefix0=1},[prefix1, prefix0])。从后文看,<labels, ID>Map看来是多分类才用到。二分类只用到了labels。
    • mapPartition 分区调用 CalLabelDetailLocal 来计算混淆矩阵,主要是分区调用getDetailStatistics,前文中得到的二元组(map, labels)会作为参数传递进来 。
      • getDetailStatistics 遍历 rows 数据,提取每一个item(比如 "prefix1,{"prefix1": 0.8, "prefix0": 0.2}"),然后通过updateBinaryMetricsSummary累积计算混淆矩阵所需数据。
        • updateBinaryMetricsSummary 把 [0,1] 分成假设 100000个桶(bin)。所以得到positiveBin / negativeBin 两个100000的数组。positiveBin就是 TP + FP,negativeBin就是 TN + FN。
          • 如果某个 sample 为 正例 (positive value) 的概率是 p, 则该 sample 对应的 bin index 就是 p * 100000。如果 p 被预测为正例 (positive value) ,则positiveBin[index]++,
          • 否则就是被预测为负例(negative value) ,则negativeBin[index]++。
  • 综合reduce上述计算结果,metrics = res.reduce(new EvaluationUtil.ReduceBaseMetrics());
    • 具体计算是在BinaryMetricsSummary.merge,其作用就是Merge the bins, and add the logLoss。
  • 把最终数值输入到 output table,setOutput(metrics.flatMap(new EvaluationUtil.SaveDataAsParams()..);
    • 归并所有BaseMetrics后,得到total BaseMetrics,计算indexes存入params。collector.collect(t.toMetrics().serialize());
      • 实际业务在BinaryMetricsSummary.toMetrics,即基于bin的信息计算,然后存储到params。
        • extractMatrixThreCurve函数取出非空的bins,据此计算出ConfusionMatrix array(混淆矩阵), threshold array, rocCurve/recallPrecisionCurve/LiftChart.
          • 遍历bins中每一个有意义的点,计算出totalTrue和totalFalse,并且在每一个点上计算:
          • curTrue += positiveBin[index]; curFalse += negativeBin[index];
          • 得到该点的混淆矩阵 new ConfusionMatrix(new long[][] { {curTrue, curFalse}, {totalTrue - curTrue, totalFalse - curFalse}});
          • 得到 tpr = (totalTrue == 0 ? 1.0 : 1.0 * curTrue / totalTrue);
          • rocCurve,recallPrecisionCurve,liftChart在该点对应的数据;
        • 依据曲线内容计算并且存储 AUC/PRC/KS
        • 对生成的rocCurve/recallPrecisionCurve/LiftChart输出进行抽样
        • 依据抽样后的输出存储 RocCurve/RecallPrecisionCurve/LiftChar
        • 存储正例样本的度量指标
        • 存储Logloss
        • Pick the middle point where threshold is 0.5.

3.2.1 calLabelPredDetailLocal

本函数按照partition分别计算评估指标 evaluation metrics。是的,这代码很短,但是有个地方需要注意。有时候越简单的地方越容易疏漏。容易疏漏点是:

第一行代码的结果 labels 是第二行代码的参数,而并非第二行主体。第二行代码主体和第一行代码主体一样,都是data。

private static DataSet<BaseMetricsSummary> calLabelPredDetailLocal(DataSet<Row> data, final String positiveValue, oolean binary) {
  
    DataSet<Tuple2<Map<String, Integer>, String[]>> labels = data.flatMap(new FlatMapFunction<Row, String>() {
        @Override
        public void flatMap(Row row, Collector<String> collector) {
            TreeMap<String, Double> labelProbMap;
            if (EvaluationUtil.checkRowFieldNotNull(row)) {
                labelProbMap = EvaluationUtil.extractLabelProbMap(row);
                labelProbMap.keySet().forEach(collector::collect);
                collector.collect(row.getField(0).toString());
            }
        }
    }).reduceGroup(new EvaluationUtil.DistinctLabelIndexMap(binary, positiveValue));

    return data
        .rebalance()
        .mapPartition(new CalLabelDetailLocal(binary))
        .withBroadcastSet(labels, LABELS);
}

calLabelPredDetailLocal中具体分为三步骤:

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值