问题描述
在二分类问题中,评测某种算法有很多指标,很多论文是通过比较 F1-score 来证明自己的算法是可行的。best F1-score 是指在不确定阈值的情况下,如何找到最合适的阈值,使得 F1-score 值最大。
F1-score 计算方法
TP/FP/TN/FN
全称 | 真实值(标签,label) | 预测值(predict) | |
---|---|---|---|
TP | True Positive | 1 | 1 |
FP | False Positive | 0 | 1 |
TN | True Negative | 0 | 0 |
FN | False Negative | 1 | 0 |
Positive 与 Negative 是指预测,True 与 False 是指预测与结果是否一致。
精准度(precision)
p r e c i s i o n = T P T P + F P precision = \frac{TP}{TP+FP} precision=TP+FPTP
召回率(recall)
r e c a l l = T P T P + F N recall = \frac{TP}{TP+FN} recall=TP+FNTP
F1-score
F1-score = 2 ∗ r e c a l l ∗ p r e c i s i o n r e c a l l + p r e c i s i o n \text{F1-score} = \frac{2*recall*precision}{recall+precision} F1-score=recall+precision2∗recall∗precision
当然,如果感兴趣的话可以代入求解
F1-score = 2 ∗ T P T P + F P ∗ T P T P + F N T P T P + F P + T P T P + F N = 2 ∗ T P 2 ∗ T P + F P + F N \text{F1-score} = \frac{2*\frac{TP}{TP+FP}*\frac{TP}{TP+FN}}{\frac{TP}{TP+FP}+\frac{TP}{TP+FN}} \\ =\frac{2*TP}{2*TP+FP+FN} F1-score=TP+FPTP+TP+FNTP2∗TP+FPTP∗TP+FNTP=2∗TP+FP+FN2∗TP
根据阈值打标
很多算法进行二分类时,返回的是二分类的概率值,然后根据阈值来确定具体分类。
一般情况下,都是通过比较大小关系而进行标记。换句话说,对于概率值序列 S = { a 0 , a 1 , . . . , a n − 1 } S=\{a_0,a_1,...,a_{n-1}\} S={a0,a1,...,an−1},需要找到最好的阈值 α \alpha α ,如果 a i ≥ α a_i \ge \alpha ai≥α 则, p r e d i = 1 pred_i = 1 predi=1。一般而言, α ∈ S \alpha \in S α∈S。
找到最合适的阈值
直接使用 sklearn 提供的方法比较简单,但这里先简单介绍一下基本原理。
计算方法也非常简单粗暴,直接把可能阈值全部计算一遍,得到一个 F1-score 数组,然后找到最大值以及对一个的阈值即可。
from sklearn.metrics import precision_recall_curve
import numpy as np
predict = [0.1, 0.2, 0.3, 0.4, 0.5,
0.5, 0.6, 0.6, 0.7, 0.7,
0.8, 0.8, 0.8, 0.8, 0.8,
0.8, 0.9, 0.9, 0.9, 0.9]
label = [0, 0, 0, 0, 0,
1, 1, 0, 1, 1,
1, 1, 1, 1, 1,
1, 1, 1, 1, 1]
precisions, recalls, thresholds = precision_recall_curve(label,predict)
# 拿到最优结果以及索引
f1_scores = (2 * precisions * recalls) / (precisions + recalls)
best_f1_score = np.max(f1_scores[np.isfinite(f1_scores)])
best_f1_score_index = np.argmax(f1_scores[np.isfinite(f1_scores)])
# 阈值
best_f1_score, thresholds[best_f1_score_index]
输出的内容为:
(0.9333333333333333, 0.5)
Smileyan
2021.7.9 18:28