在此记录一下分类模型常用的指标,这里不区分二分类和多分类,实际上本人觉得不需要分这么清,无非是想找一个指标或者标准去衡量分类模型的好而已。
常见评估指标
- 准确度 (accurary)关心全局的分类,所有类别预测对的样本个数除于总样本个数。
- 精确度(precision)只关心预测的那一个分类,预测对的样本个数除于预测该分类的总数。
- 召回率(recall)只关心预测的那一个分类,预测对的样本个数除于该分类的总数。
- F1-score 是精确度和召回率的调和平均值,计算公式: 1 F 1 = 1 2 ( 1 p r e c i s i o n + 1 r e c a l l ) \frac{1}{F1} = \frac{1}{2}(\frac{1}{precision} + \frac{1}{recall}) F11=21(precision1+recall1)
注意:
- 在刚接触这些概念的时候不要试图用那些什么TP、FN的公式去记忆,这是一件很绕的事情,通过文字完全知道这几个概念想表达什么。
- 精确度和召回率都是直关心某一个分类的结果,两者分子相同,精确度的分母是不确定的,召回率的分母是确定的(某一类标签的样本数)。
说明
假设有10个类别的多分类任务,每个类别有100个测试样本,共1000个样本。现讨论类别为2的样本计算:
- 准确度 是全局考虑的,模型预测了890个样本的标签和真实标签都一致,则accurary=890/1000。
- 精准度 模型预测标签2的样本有150个(也可以小于100个),其中关心的类别2预测对了80个,则precision=80/150。(分母是不固定的)
- 召回率 只看关心的类别2,100个样本中模型成功预测了67个样本,则recall=67/100。(分母固定)
- F1-score 类别2的F1值为 F1 = 2(precision * recall)/(precision + recall)。
为什么要有这么多的评估指标呢?
- 对于一个不对称的测试集中准确度高有时可能不是我们想要的,假设有个癌症评估模型,那么正常的数据可能会占绝大部分,1000份测试样本中900个是正常的,100个异常的,此时模型只需要把那900个样本预测出来,那么模型的准确度也是非常高的,但是没有任何用。
- 精准度个人理解是允许模型犯一点点错误,但是要尽可能的把对的样本都找出来。
- 召回率个人理解是严格要求模型不要犯错误,在不犯错误的前提下尽可能的把对的样本都找出来。
- 精准度和召回率都有对应不同的应用场景,可见下方知乎高赞参考链接。
F1 score可视化
F1计算公式如下:
1 F 1 = 1 2 ( 1 p r e c i s i o n + 1 r e c a l l ) F 1 = 2 ⋅ p r e c i s i o n ⋅ r e c a l l p r e c i s i o n + r e c a l l \begin{equation} \begin{split} \frac{1}{F1} &= \frac{1}{2}(\frac{1}{precision} + \frac{1}{recall}) \\ F1 &= \frac{2\cdot precision \cdot recall}{precision + recall} \end{split} \end{equation} F11F1=21(precision1+recall1)=precision+recall2⋅precision⋅recall
实际上precision和recall这两个变量在数学的形式上是等价的,先固定precision=0.5,可以看到F1随着recall从0.2到0.9的变化曲线是一个非线性递增的,由于precision的影响F1并不会变化的很快,在recall=0.899时候,F1的值也才0.6412不是很好。
代码实现
import matplotlib.pyplot as plt
import numpy as np
import math
precision = 0.5
recall = np.linspace(0.2, 0.9, 100)
def f1(precision, recall):
return 2 * precision * recall / (precision + recall)
plt.plot(recall, f1(precision, recall))
plt.show()
当precision和recall都在逐渐增加大时,两者对F1的影响如下,只有precision和recall都在较大的范围时,F1的值才会比较高。
代码实现
import matplotlib.pyplot as plt
import numpy as np
import math
precision = np.linspace(0.5, 0.99, 100)
recall = np.linspace(0.2, 1, 100)
def f1(precision, recall):
x1 = 2 * precision * recall
x2 = precision + recall
y = []
for i in range(len(x1)):
y.append(x1[i] / x2[i])
return np.array(y)
plt.plot(recall, f1(precision, recall))
plt.show()
混淆矩阵 confusion matrix
混淆矩阵其实就是对每一类样本数据进行可视化,只不过是方阵的表示形式,下图是15个类型的混淆矩阵,每个类别的样本数由support给出。
confusion 从15*15的方阵,我们看下该如何解读。
从行来解读看是当前类别预测的情况,比如第一行数字相加是该类别的测试总样本,其中A类别共有34个样本其中预测为A的分类有33个,1个错预测成了C类。那么recall = 33 / 33+1=0.9705882352941176和统计信息的A类中的recall是一致的。
从列来解读是预测成该类别的情况,比如我们看第4列,预测成D类别的总数是22+1+1=24个样本,其中预测对的样本是22个,则precision=22 / 22+1+1=0.9166666666666666和统计信息D类别的precision是一致的。
可以总结出confusion matrix 行是召回率信息,列是精确率信息。
如果某一行的两个非零数据相差不大,可以明显知道这两个类别的区分度不是很高(可能要从训练集、网络结构、参数权重考虑优化)。
如果某一列的非零数据很多,可以知道模型本身太容易犯错了,模型开始四亲不认了,什么数据都预测成了该列所对应的类别了(这是有一定业务场景的,比如二分类的垃圾邮件过滤中,宁愿让模型犯一点错误将垃圾邮件分类成正常邮件,也不愿将女神发来的约会邮件当成垃圾邮件处理了)至此confusion matrix应该是比较清晰了。
总结
- 准确度 考虑全局所有样本,预测对的除以总的样本数,比较直接。
- 精确度 允许模型犯点错误,针对某一类别预测对的除以总的预测数。
- 召回率 严格模型尽量不犯错,针对某一类别预测对的除以该类别测试数。
- F1 精确度和召回率调和平均值,两者都高的情况下F1才会高。
- 混淆矩阵 类别组成的方阵 行是召回率分析,列是精确率分析。
- 遗留混淆矩阵中的macro avg和weighted avg参数分析和PR曲线、ROC曲线,未完待续~
参考文章
https://zhuanlan.zhihu.com/p/147663370