Spark MLlib求解Precision, Recall, F1 使用Java
Maven依赖
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-core_2.11</artifactId>
<version>2.2.0</version>
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-sql_2.11</artifactId>
<version>2.2.0</version>
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-mllib_2.11</artifactId>
<version>2.2.0</version>
</dependency>
使用的核心类
org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
理论准备
在机器学习中的二分类问题,仅仅使用accuracy不足够准确和足以度量模型,尤其是当数据集正负样本不均衡时。
例子:
银行对用户分类为二类即信用好和差,来进行发放信用卡。
那么即使不做数据挖掘,而是直接判定所有用户的信用都是好的,那么accuracy也是够高的。
但是很显然,银行对信用差的用户更加在意,需要对这少数群体更慎重划分。
Precicion, Recall, F1
首先引入混淆矩阵和True Positive(TP), False Positive(FP), False Negative(FN), True Negative(TN)。
Positive样本(对应1/0中的1)对应的是 数量少的那方,也就是上面例子中的信用差的群体。
label=1 | label=0 | |
---|---|---|
prediction=1 | TP | FP |
prediction=0 | FN | TN |
对上面的混淆矩阵解释一下:
对于TP, FP, FN, TN都是两个字母:
- 第1个字母:是预测做的的对不对,显然当预测和label一致时是T,否认是F。
- 第2个字母:是predict的值,预测什么值这里就是什么。
计算Precision
P
r
e
c
i
s
i
o
n
=
T
P
T
P
+
F
P
Precision=\frac{TP}{TP + FP}
Precision=TP+FPTP
即对于表格内容的第一行而言,TP占比。
计算Recall
P
r
e
c
i
s
i
o
n
=
T
P
T
P
+
F
N
Precision=\frac{TP}{TP + FN}
Precision=TP+FNTP
即对于表格内容的第一列而言,TP占比。
计算F1
F 1 = 2 ∗ P r e c i s i o n ∗ R e c a l l P r e c i s i o n + R e c a l l = 2 ∗ T P 2 ∗ T P + F P + F N F1=\frac{2 * Precision * Recall}{Precision + Recall} = \frac{2*TP}{2*TP + FP + FN} F1=Precision+Recall2∗Precision∗Recall=2∗TP+FP+FN2∗TP
注意
- 从上面可知,TN一直没有用到。实际上,在计算accuracy时会用到它,计算公式
a c c u r a c y = T P + T N T P + F P + F N + T N accuracy = \frac{TP+TN}{TP+FP+FN+TN} accuracy=TP+FP+FN+TNTP+TN - 一定要以数量少的那一类样本当成Positive。
Java代码
package ml;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.FlatMapFunction;
import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics;
import org.apache.spark.rdd.RDD;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.types.*;
import scala.Tuple2;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
public class TestPrecisionRecallF1Java {
public static void main(final String[] args) {
System.out.println("test in java");
SparkConf conf = new SparkConf().setAppName("test in java").setMaster("local");
JavaSparkContext jsc = new JavaSparkContext(conf);
jsc.setLogLevel("WARN");
/**
* 数量
* TP=1, FP=0
* FN=10, TN=89
*/
// schema
List<StructField> fields = new ArrayList<StructField>();
fields.add(DataTypes.createStructField("pred", DataTypes.DoubleType, true));
fields.add(DataTypes.createStructField("label", DataTypes.DoubleType, true));
final StructType schema = DataTypes.createStructType(fields);
SQLContext sqlContext = new SQLContext(jsc);
final JavaRDD<Row> rowJavaRDD = sqlContext.read().csv("*此处是你的路径,数据在下面**/data.csv").javaRDD();
// String 2 Double
final JavaRDD<Row> rowJavaRDD1 = rowJavaRDD.mapPartitions(new FlatMapFunction<Iterator<Row>, Row>() {
private static final long serialVersionUID = 1L;
public Iterator<Row> call(Iterator<Row> iterator) throws Exception {
List<Row> result = new ArrayList<Row>();
Row row;
while(iterator.hasNext()) {
row = iterator.next();
Double aDouble0 = Double.valueOf(row.getString(0));
Double aDouble1 = Double.valueOf(row.getString(1));
result.add(RowFactory.create(aDouble0, aDouble1));
}
return result.iterator();
}
});
Dataset<Row> df = sqlContext.createDataFrame(rowJavaRDD1, schema);
// 上面的都是数据准备工作。下面真正开始计算
BinaryClassificationMetrics metrics = new BinaryClassificationMetrics(df); // 核心。
RDD<Tuple2<Object, Object>> precision = metrics.precisionByThreshold(); // 计算precisoin
RDD<Tuple2<Object, Object>> recall = metrics.recallByThreshold(); // 计算recall
RDD<Tuple2<Object, Object>> f1 = metrics.fMeasureByThreshold(); // 计算F1
System.out.println("precision: "+precision.toJavaRDD().collect());
System.out.println("recall : "+recall.toJavaRDD().collect());
System.out.println("f1 : "+f1.toJavaRDD().collect());
}
}
运行结果:
precision: [(1.0,1.0), (0.0,0.11)]
recall : [(1.0,0.09090909090909091), (0.0,1.0)]
f1 : [(1.0,0.16666666666666669), (0.0,0.19819819819819817)]
只需要关注数组中的第一个Tuple2,即
precision: (1.0,1.0) 关注第2个数值,是1.0, 指的是precision=1.0
recall : (1.0,0.09090909090909091) 关注第2个数值0.0909,recall=0.0909
f1 : (1.0,0.16666666666666669) 同理,f1=0.1666
回到发放信用卡问题上
Positive类别是信用差(对应Positive)的用户群体。当我们把大部分用户判定为信用好(对应Negative类别),就如同data.csv中,TP+TN占比90%。虽然这个accuracy很高,但是recall, F1度量值很低。
这里若使用recall或F1来度量和评定模型效果的话,会更加有说服性。
用到的数据集
data.csv。第一列是prediction,第二列是label。
从上到下是:1个TP,0个FP,10个FN,89个TN。
数据集中有11个label=1即信用差的用户,而TP=1即对信用差的用户只成功预测了1次。
1.0, 1.0
0.0, 1.0
0.0, 1.0
0.0, 1.0
0.0, 1.0
0.0, 1.0
0.0, 1.0
0.0, 1.0
0.0, 1.0
0.0, 1.0
0.0, 1.0
0.0, 0.0
0.0, 0.0
0.0, 0.0
0.0, 0.0
0.0, 0.0
0.0, 0.0
0.0, 0.0
0.0, 0.0
0.0, 0.0
0.0, 0.0
0.0, 0.0
0.0, 0.0
0.0, 0.0
0.0, 0.0
0.0, 0.0
0.0, 0.0
0.0, 0.0
0.0, 0.0
0.0, 0.0
0.0, 0.0
0.0, 0.0
0.0, 0.0
0.0, 0.0
0.0, 0.0
0.0, 0.0
0.0, 0.0
0.0, 0.0
0.0, 0.0
0.0, 0.0
0.0, 0.0
0.0, 0.0
0.0, 0.0
0.0, 0.0
0.0, 0.0
0.0, 0.0
0.0, 0.0
0.0, 0.0
0.0, 0.0
0.0, 0.0
0.0, 0.0
0.0, 0.0
0.0, 0.0
0.0, 0.0
0.0, 0.0
0.0, 0.0
0.0, 0.0
0.0, 0.0
0.0, 0.0
0.0, 0.0
0.0, 0.0
0.0, 0.0
0.0, 0.0
0.0, 0.0
0.0, 0.0
0.0, 0.0
0.0, 0.0
0.0, 0.0
0.0, 0.0
0.0, 0.0
0.0, 0.0
0.0, 0.0
0.0, 0.0
0.0, 0.0
0.0, 0.0
0.0, 0.0
0.0, 0.0
0.0, 0.0
0.0, 0.0
0.0, 0.0
0.0, 0.0
0.0, 0.0
0.0, 0.0
0.0, 0.0
0.0, 0.0
0.0, 0.0
0.0, 0.0
0.0, 0.0
0.0, 0.0
0.0, 0.0
0.0, 0.0
0.0, 0.0
0.0, 0.0
0.0, 0.0
0.0, 0.0
0.0, 0.0
0.0, 0.0
0.0, 0.0
0.0, 0.0
0.0, 0.0
参考
https://blog.csdn.net/wo334499/article/details/51689609/