Spark MLlib求解机器学习Precision, Recall, F1值 (Java代码)

Maven依赖

 <dependency>
     <groupId>org.apache.spark</groupId>
     <artifactId>spark-core_2.11</artifactId>
     <version>2.2.0</version>
 </dependency>
 <dependency>
     <groupId>org.apache.spark</groupId>
     <artifactId>spark-sql_2.11</artifactId>
     <version>2.2.0</version>
 </dependency>
 <dependency>
     <groupId>org.apache.spark</groupId>
     <artifactId>spark-mllib_2.11</artifactId>
     <version>2.2.0</version>
 </dependency>

使用的核心类

org.apache.spark.mllib.evaluation.BinaryClassificationMetrics

理论准备

在机器学习中的二分类问题,仅仅使用accuracy不足够准确和足以度量模型,尤其是当数据集正负样本不均衡时。
例子:
银行对用户分类为二类即信用好和差,来进行发放信用卡。
那么即使不做数据挖掘,而是直接判定所有用户的信用都是好的,那么accuracy也是够高的。
但是很显然,银行对信用差的用户更加在意,需要对这少数群体更慎重划分。

Precicion, Recall, F1

首先引入混淆矩阵和True Positive(TP), False Positive(FP), False Negative(FN), True Negative(TN)。
Positive样本(对应1/0中的1)对应的是 数量少的那方,也就是上面例子中的信用差的群体。

label=1label=0
prediction=1TPFP
prediction=0FNTN

对上面的混淆矩阵解释一下:
对于TP, FP, FN, TN都是两个字母:

  • 第1个字母:是预测做的的对不对,显然当预测和label一致时是T,否认是F。
  • 第2个字母:是predict的值,预测什么值这里就是什么。

计算Precision

P r e c i s i o n = T P T P + F P Precision=\frac{TP}{TP + FP} Precision=TP+FPTP
Precision即对于表格内容的第一行而言,TP占比。

计算Recall

P r e c i s i o n = T P T P + F N Precision=\frac{TP}{TP + FN} Precision=TP+FNTP
Recall即对于表格内容的第一列而言,TP占比。

计算F1

F 1 = 2 ∗ P r e c i s i o n ∗ R e c a l l P r e c i s i o n + R e c a l l = 2 ∗ T P 2 ∗ T P + F P + F N F1=\frac{2 * Precision * Recall}{Precision + Recall} = \frac{2*TP}{2*TP + FP + FN} F1=Precision+Recall2PrecisionRecall=2TP+FP+FN2TP

注意

  1. 从上面可知,TN一直没有用到。实际上,在计算accuracy时会用到它,计算公式
    a c c u r a c y = T P + T N T P + F P + F N + T N accuracy = \frac{TP+TN}{TP+FP+FN+TN} accuracy=TP+FP+FN+TNTP+TN
  2. 一定要以数量少的那一类样本当成Positive。

Java代码

package ml;

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.FlatMapFunction;
import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics;
import org.apache.spark.rdd.RDD;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.types.*;
import scala.Tuple2;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;

public class TestPrecisionRecallF1Java {
    public static void main(final String[] args) {
        System.out.println("test in java");

        SparkConf conf = new SparkConf().setAppName("test in java").setMaster("local");
        JavaSparkContext jsc = new JavaSparkContext(conf);
        jsc.setLogLevel("WARN");

        /** 
         * 数量
         * TP=1,   FP=0
         * FN=10,  TN=89
         */

        // schema
        List<StructField> fields = new ArrayList<StructField>();
        fields.add(DataTypes.createStructField("pred", DataTypes.DoubleType, true));
        fields.add(DataTypes.createStructField("label", DataTypes.DoubleType, true));
        final StructType schema = DataTypes.createStructType(fields);

        SQLContext sqlContext = new SQLContext(jsc);
        final JavaRDD<Row> rowJavaRDD = sqlContext.read().csv("*此处是你的路径,数据在下面**/data.csv").javaRDD();
        // String 2 Double
        final JavaRDD<Row> rowJavaRDD1 = rowJavaRDD.mapPartitions(new FlatMapFunction<Iterator<Row>, Row>() {

            private static final long serialVersionUID = 1L;

            public Iterator<Row> call(Iterator<Row> iterator) throws Exception {
                List<Row> result = new ArrayList<Row>();
                Row row;

                while(iterator.hasNext()) {
                    row = iterator.next();
                    Double aDouble0 = Double.valueOf(row.getString(0));
                    Double aDouble1 = Double.valueOf(row.getString(1));
                    result.add(RowFactory.create(aDouble0, aDouble1));
                }
                return result.iterator();
            }
        });

        Dataset<Row> df = sqlContext.createDataFrame(rowJavaRDD1, schema);

		// 上面的都是数据准备工作。下面真正开始计算
        BinaryClassificationMetrics metrics = new BinaryClassificationMetrics(df); // 核心。
        RDD<Tuple2<Object, Object>> precision = metrics.precisionByThreshold();  // 计算precisoin
        RDD<Tuple2<Object, Object>> recall = metrics.recallByThreshold();  // 计算recall
        RDD<Tuple2<Object, Object>> f1 = metrics.fMeasureByThreshold();  // 计算F1

        System.out.println("precision: "+precision.toJavaRDD().collect());
        System.out.println("recall   : "+recall.toJavaRDD().collect());
        System.out.println("f1       : "+f1.toJavaRDD().collect());
    }
}

运行结果:

precision: [(1.0,1.0), (0.0,0.11)]
recall   : [(1.0,0.09090909090909091), (0.0,1.0)]
f1       : [(1.0,0.16666666666666669), (0.0,0.19819819819819817)]

只需要关注数组中的第一个Tuple2,即

precision: (1.0,1.0)  关注第2个数值,是1.0, 指的是precision=1.0
recall   : (1.0,0.09090909090909091)  关注第2个数值0.0909,recall=0.0909
f1       : (1.0,0.16666666666666669)   同理,f1=0.1666

回到发放信用卡问题上

Positive类别是信用差(对应Positive)的用户群体。当我们把大部分用户判定为信用好(对应Negative类别),就如同data.csv中,TP+TN占比90%。虽然这个accuracy很高,但是recall, F1度量值很低。
这里若使用recall或F1来度量和评定模型效果的话,会更加有说服性。

用到的数据集

data.csv。第一列是prediction,第二列是label。
从上到下是:1个TP,0个FP,10个FN,89个TN。
数据集中有11个label=1即信用差的用户,而TP=1即对信用差的用户只成功预测了1次。

1.0, 1.0
0.0, 1.0
0.0, 1.0
0.0, 1.0
0.0, 1.0
0.0, 1.0
0.0, 1.0
0.0, 1.0
0.0, 1.0
0.0, 1.0
0.0, 1.0
0.0, 0.0
0.0, 0.0
0.0, 0.0
0.0, 0.0
0.0, 0.0
0.0, 0.0
0.0, 0.0
0.0, 0.0
0.0, 0.0
0.0, 0.0
0.0, 0.0
0.0, 0.0
0.0, 0.0
0.0, 0.0
0.0, 0.0
0.0, 0.0
0.0, 0.0
0.0, 0.0
0.0, 0.0
0.0, 0.0
0.0, 0.0
0.0, 0.0
0.0, 0.0
0.0, 0.0
0.0, 0.0
0.0, 0.0
0.0, 0.0
0.0, 0.0
0.0, 0.0
0.0, 0.0
0.0, 0.0
0.0, 0.0
0.0, 0.0
0.0, 0.0
0.0, 0.0
0.0, 0.0
0.0, 0.0
0.0, 0.0
0.0, 0.0
0.0, 0.0
0.0, 0.0
0.0, 0.0
0.0, 0.0
0.0, 0.0
0.0, 0.0
0.0, 0.0
0.0, 0.0
0.0, 0.0
0.0, 0.0
0.0, 0.0
0.0, 0.0
0.0, 0.0
0.0, 0.0
0.0, 0.0
0.0, 0.0
0.0, 0.0
0.0, 0.0
0.0, 0.0
0.0, 0.0
0.0, 0.0
0.0, 0.0
0.0, 0.0
0.0, 0.0
0.0, 0.0
0.0, 0.0
0.0, 0.0
0.0, 0.0
0.0, 0.0
0.0, 0.0
0.0, 0.0
0.0, 0.0
0.0, 0.0
0.0, 0.0
0.0, 0.0
0.0, 0.0
0.0, 0.0
0.0, 0.0
0.0, 0.0
0.0, 0.0
0.0, 0.0
0.0, 0.0
0.0, 0.0
0.0, 0.0
0.0, 0.0
0.0, 0.0
0.0, 0.0
0.0, 0.0
0.0, 0.0
0.0, 0.0

参考

https://blog.csdn.net/wo334499/article/details/51689609/

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值