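0 Preface
Some time ago I worked on a classification project that used ResNet18 and MobileNetV2 to classify data, the former mainly for GPU deployment and the latter mainly for CPU deployment. I evaluated the classifiers chiefly through the confusion matrix, accuracy, recall, and the F Score. This post explains those metrics in detail.
Video walkthrough: [Deep Learning] [Machine Learning] [Code Walkthrough] Metrics and Methods for Analyzing Classification Results (Confusion Matrix, TP, TN, FP, FN, Precision, Recall) (source code included), on bilibili.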
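1 Evaluation Metrics
1.1 TP, FP, FN, TN
First, what the abbreviations stand for: T is True, P is Positive, F is False, and N is Negative. Therefore:
TP: the actual value is Positive and the predicted value is Positive, i.e. a true positive;
FP: the actual value is Negative and the predicted value is Positive, i.e. a false positive;
FN: the actual value is Positive and the predicted value is Negative, i.e. a false negative;
TN: the actual value is Negative and the predicted value is Negative, i.e. a true negative.
In statistics, FP is known as a Type I error and FN as a Type II error.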
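The four counts can be tallied directly from paired lists of actual and predicted labels. The snippet below is only a minimal sketch: the helper name `count_outcomes` and the sample labels are made up for illustration and are not part of the original project.

```python
def count_outcomes(y_true, y_pred, positive=1):
    """Tally TP, FP, FN, TN for one positive label (hypothetical helper)."""
    tp = fp = fn = tn = 0
    for actual, predicted in zip(y_true, y_pred):
        if predicted == positive:
            if actual == positive:
                tp += 1  # predicted Positive, actually Positive
            else:
                fp += 1  # predicted Positive, actually Negative (Type I error)
        else:
            if actual == positive:
                fn += 1  # predicted Negative, actually Positive (Type II error)
            else:
                tn += 1  # predicted Negative, actually Negative
    return tp, fp, fn, tn


# Made-up labels: 1 = Positive, 0 = Negative
y_true = [1, 1, 1, 0, 0, 0, 0, 1]
y_pred = [1, 0, 1, 0, 1, 0, 0, 1]
tp, fp, fn, tn = count_outcomes(y_true, y_pred)
print(tp, fp, fn, tn)  # 3 1 1 3
```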
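1.2 Confusion Matrix
Arranging TP, FP, FN, and TN in a table gives the confusion matrix.
1.2.1 Binary Classification
In binary classification, each sample has only two possible predictions: positive or negative. The general layout is shown in Table 1.
Table 1. Confusion matrix for binary classification
| Actual \ Predicted | Positive | Negative |
| Positive | TP | FN |
| Negative | FP | TN |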
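For example, suppose we classify pictures of cats and dogs, with cat as the positive class and dog as the negative class, so a prediction is either "cat" or "not cat" (dog). If there are 50 cat pictures and 53 dog pictures, of which 45 and 47 respectively are predicted correctly, the confusion matrix is as shown in Table 2.
Table 2. Confusion matrix for the cat/dog example
| Actual \ Predicted | Cat | Dog |
| Cat | 45 | 5 |
| Dog | 6 | 47 |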
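1.2.2 Multi-class Classification
Classification tasks are often not binary; they may have three or more classes, and the confusion matrix then differs from the binary case. The confusion matrix for N classes is shown in Table 3.
Table 3. Confusion matrix for N-class classification
| Actual \ Predicted | class 1 | class 2 | ······ | class N |
| class 1 | | | | |
| class 2 | | | | |
| ······ | | | | |
| class N | | | | |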
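For example, suppose we classify pictures of cats, dogs, and pigs, with 51, 52, and 49 pictures respectively. The cat pictures are predicted as cat, dog, and pig 47, 1, and 3 times; the dog pictures 1, 49, and 2 times; and the pig pictures 1, 0, and 48 times. The confusion matrix is then as shown in Table 4.
Table 4. Confusion matrix for the cat/dog/pig example
| Actual \ Predicted | Cat | Dog | Pig |
| Cat | 47 | 1 | 3 |
| Dog | 1 | 49 | 2 |
| Pig | 1 | 0 | 48 |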
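A multi-class confusion matrix like Table 4 can be accumulated one sample at a time, which is also how the script in Section 2 builds it. The sketch below expands the counts from the example into individual (actual, predicted) pairs purely for illustration; the variable names are assumptions.

```python
import numpy as np

classes = ['cat', 'dog', 'pig']
# (actual index, predicted index, count) triples taken from the example above
counts = [
    (0, 0, 47), (0, 1, 1), (0, 2, 3),
    (1, 0, 1), (1, 1, 49), (1, 2, 2),
    (2, 0, 1), (2, 1, 0), (2, 2, 48),
]
# Expand to one (actual, predicted) pair per image, as a prediction log would provide
pairs = [(t, p) for t, p, n in counts for _ in range(n)]

matrix = np.zeros((len(classes), len(classes)), dtype=int)
for target, predict in pairs:
    matrix[target][predict] += 1  # rows: actual, columns: predicted

print(matrix)
print(matrix.sum(axis=1))  # number of images per actual class
```

Each row sums to the number of pictures of that class (51, 52, and 49), which is a quick sanity check on the matrix.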
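The three cats misclassified as pigs were, in all likelihood, chubby orange tabbies.
Each class can also be analyzed as a binary problem of its own, where the prediction is either this class or not this class, with its own confusion matrix. For the cat class, for instance, predictions of dog and pig are lumped together as "not cat", and the resulting confusion matrix is shown in Table 5.
Table 5. Confusion matrix for the cat class in the three-class example
| Actual \ Predicted | Cat | Not cat |
| Cat | 47 | 4 |
| Not cat | 2 | 99 |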
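Collapsing a multi-class matrix into the per-class binary view can be done with row and column sums. This is a minimal sketch assuming rows are actual classes and columns are predicted classes; `one_vs_rest` is a hypothetical helper name.

```python
import numpy as np


def one_vs_rest(matrix, k):
    """Return (TP, FN, FP, TN) for class k; rows are actual, columns are predicted."""
    matrix = np.asarray(matrix)
    tp = matrix[k, k]
    fn = matrix[k, :].sum() - tp   # actual k, predicted as something else
    fp = matrix[:, k].sum() - tp   # predicted k, actually something else
    tn = matrix.sum() - tp - fn - fp
    return int(tp), int(fn), int(fp), int(tn)


table4 = [[47, 1, 3], [1, 49, 2], [1, 0, 48]]
cat = one_vs_rest(table4, 0)
print(cat)  # (47, 4, 2, 99)
```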
1.3 Secondary Metrics
The main secondary metrics are accuracy, precision, recall, and specificity.
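1.3.1 Accuracy
Accuracy: the proportion of all observations that are classified correctly. Accuracy describes the model as a whole and is calculated as follows.
(1) Multi-class model
$$ACC = \frac{\text{number of correct classifications}}{\text{total number of observations}}$$
For Table 4, the accuracy is $ACC = \frac{47 + 49 + 48}{152} = \frac{144}{152} \approx 94.74\%$.
In a multi-class model, the accuracy for class $k$ is $ACC_k = \frac{TP_k + TN_k}{TP_k + TN_k + FP_k + FN_k}$.
For Table 5, the accuracy for cat in the three-class task is $ACC_{cat} = \frac{47 + 99}{47 + 99 + 2 + 4} = \frac{146}{152} \approx 96.05\%$.
(2) Binary model
$$ACC = \frac{TP + TN}{TP + TN + FP + FN}$$
For Table 2, the accuracy is $ACC = \frac{45 + 47}{45 + 47 + 6 + 5} = \frac{92}{103} \approx 89.32\%$.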
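The accuracy figures for Tables 2, 4, and 5 can be checked with a few lines of NumPy. This is an illustrative sketch; the function names `overall_accuracy` and `class_accuracy` are assumptions, not the names used in the source code of Section 2.

```python
import numpy as np

table2 = np.array([[45, 5], [6, 47]])                    # cat / dog
table4 = np.array([[47, 1, 3], [1, 49, 2], [1, 0, 48]])  # cat / dog / pig


def overall_accuracy(matrix):
    """Correct predictions (the diagonal) divided by all observations."""
    return np.trace(matrix) / matrix.sum()


def class_accuracy(matrix, k):
    """(TP + TN) / total for class k treated as one-vs-rest."""
    tp = matrix[k, k]
    tn = matrix.sum() - matrix[k, :].sum() - matrix[:, k].sum() + tp
    return (tp + tn) / matrix.sum()


acc_table2 = overall_accuracy(table2)   # 92 / 103
acc_table4 = overall_accuracy(table4)   # 144 / 152
acc_cat = class_accuracy(table4, 0)     # 146 / 152
print('{:.2f}% {:.2f}% {:.2f}%'.format(acc_table2 * 100, acc_table4 * 100, acc_cat * 100))
```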
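1.3.2 Precision
Precision: among all results that the model predicts as Positive, the proportion that it predicts correctly. It is calculated as:
$$PPV = \frac{TP}{TP + FP}$$
In Table 2, the precision for cat is $\frac{45}{45 + 6} \approx 88.24\%$; in Table 5, the precision for cat in the three-class task is $\frac{47}{47 + 2} \approx 95.92\%$.
Precision is tied to the predictions. Put simply, it is the metric to watch when "wrongful convictions" (false positives) are costly and "fish that slip through the net" (false negatives) are cheap. Take spam filtering, where True means an email is spam and False means it is not. If a spam email is mistaken for a normal one, we only waste a few seconds opening it; but if an important email is thrown into the spam folder, we may miss critical information. In this case FP should be as small as possible, so that for a fixed TP the precision (PPV) is as large as possible.
In information retrieval, precision is also known by the Chinese term 查准率 (roughly, "retrieval accuracy rate").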
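Because precision looks at everything predicted as a class, it corresponds to the columns of the confusion matrix. The sketch below computes it for every class at once, assuming rows are actual classes and columns are predicted classes; `precision_per_class` is a hypothetical name.

```python
import numpy as np


def precision_per_class(matrix):
    """Diagonal over column sums: TP / (TP + FP) for each class."""
    matrix = np.asarray(matrix, dtype=float)
    return np.diag(matrix) / matrix.sum(axis=0)


ppv_table2 = precision_per_class([[45, 5], [6, 47]])                     # cat, dog
ppv_table4 = precision_per_class([[47, 1, 3], [1, 49, 2], [1, 0, 48]])   # cat, dog, pig
print(np.round(ppv_table2 * 100, 2))
print(np.round(ppv_table4 * 100, 2))
```

If a class is never predicted, its column sum is zero and NumPy yields NaN with a warning, so a guard would be needed in that case.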
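1.3.3 Sensitivity / Recall
Recall: among all results whose actual value is Positive, the proportion that the model predicts correctly. It is calculated as:
$$TPR = \frac{TP}{TP + FN}$$
In Table 2, the recall for cat is $\frac{45}{45 + 5} = 90.00\%$; in Table 5, the recall for cat in the three-class task is $\frac{47}{47 + 4} \approx 92.16\%$.
Recall is tied to the samples (the actual values). It asks the classification results to be broad and complete, emphasizing quantity. Put simply, it is the metric to watch when "wrongful convictions" (false positives) are cheap and "fish that slip through the net" (false negatives) are costly. Take earthquake prediction, where True means an earthquake will occur at a given time and False means it will not. If the system predicts an earthquake today and issues an early warning, then even if no earthquake occurs, people only lose some time taking shelter; but if an earthquake does occur and was not predicted, lives and property suffer severe losses.
In information retrieval, recall is also known by the Chinese term 查全率 (roughly, "retrieval completeness rate").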
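Recall looks at everything that actually belongs to a class, so it corresponds to the rows of the confusion matrix. A minimal sketch under the same layout assumption (rows actual, columns predicted), with `recall_per_class` as a hypothetical name:

```python
import numpy as np


def recall_per_class(matrix):
    """Diagonal over row sums: TP / (TP + FN) for each class."""
    matrix = np.asarray(matrix, dtype=float)
    return np.diag(matrix) / matrix.sum(axis=1)


tpr_table2 = recall_per_class([[45, 5], [6, 47]])                     # cat, dog
tpr_table4 = recall_per_class([[47, 1, 3], [1, 49, 2], [1, 0, 48]])   # cat, dog, pig
print(np.round(tpr_table2 * 100, 2))
print(np.round(tpr_table4 * 100, 2))
```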
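1.3.4 Specificity
Specificity: among all results whose actual value is Negative, the proportion that the model predicts correctly. It is calculated as:
$$TNR = \frac{TN}{TN + FP}$$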
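Specificity needs TN, which for a multi-class matrix is everything outside the row and column of the class in question. The sketch below derives it per class from the totals; the function name is an assumption, and the inputs are the matrices of Tables 2 and 4.

```python
import numpy as np


def specificity_per_class(matrix):
    """TN / (TN + FP) for each class treated as one-vs-rest."""
    matrix = np.asarray(matrix, dtype=float)
    tp = np.diag(matrix)
    fp = matrix.sum(axis=0) - tp     # column total minus the diagonal
    fn = matrix.sum(axis=1) - tp     # row total minus the diagonal
    tn = matrix.sum() - tp - fp - fn
    return tn / (tn + fp)


tnr_table2 = specificity_per_class([[45, 5], [6, 47]])
tnr_table4 = specificity_per_class([[47, 1, 3], [1, 49, 2], [1, 0, 48]])
print(np.round(tnr_table2 * 100, 2))
print(np.round(tnr_table4 * 100, 2))
```

For the cat class this gives 47/53 in the binary example and 99/101 in the three-class example.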
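1.4 Tertiary Metrics
1.4.1 F-measure
The F-measure is the weighted harmonic mean of Precision and Recall. It is calculated as:
$$F_\alpha = \frac{(1 + \alpha^2) \cdot PPV \cdot TPR}{\alpha^2 \cdot PPV + TPR}$$
1.4.2 F1-measure
When $\alpha = 1$,
$$F_1 = \frac{2 \cdot PPV \cdot TPR}{PPV + TPR}$$
The F1 Score takes values in [0, 1]. The closer it is to 0, the worse the model's output; the closer it is to 1, the better.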
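As a worked check, the sketch below evaluates the F-measure for the cat class of Table 5. It assumes the common form $F_\alpha = \frac{(1 + \alpha^2) \cdot PPV \cdot TPR}{\alpha^2 \cdot PPV + TPR}$; the function name `f_measure` and the choice to return 0 when both inputs are 0 are illustrative assumptions.

```python
def f_measure(precision, recall, alpha=1.0):
    """Weighted harmonic mean of precision and recall (assumed F-alpha form)."""
    if precision == 0 and recall == 0:
        return 0.0  # assumption: define the score as 0 instead of dividing by zero
    a2 = alpha ** 2
    return (1 + a2) * precision * recall / (a2 * precision + recall)


# Cat class of Table 5: TP = 47, FN = 4, FP = 2
ppv = 47 / (47 + 2)
tpr = 47 / (47 + 4)
f1 = f_measure(ppv, tpr)           # equals 2*TP / (2*TP + FP + FN) = 94 / 100
f2 = f_measure(ppv, tpr, alpha=2)  # weights recall more heavily
print('F1 = {:.2f}%, F2 = {:.2f}%'.format(f1 * 100, f2 * 100))
```

With $\alpha > 1$ the score leans toward recall, and with $\alpha < 1$ it leans toward precision.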
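2 Code
The code analyzes the three-class results of the ResNet18 and MobileNetV2 models. It builds and prints the confusion matrix, computes and prints the secondary and tertiary metrics, and finally saves the confusion matrix and all metrics to an Excel file for later analysis.
2.1 Source Code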
import os
import shutil

import numpy as np
import xlwt


# Set the cell style (font and borders)
def set_style(name, height, bold=False):
    style = xlwt.XFStyle()
    font = xlwt.Font()
    font.name = name
    font.bold = bold
    # xlwt spells this attribute "colour_index"; "color_index" is silently ignored
    font.colour_index = 4
    font.height = height
    style.font = font
    borders = xlwt.Borders()
    borders.left = 1
    borders.right = 1
    borders.top = 1
    borders.bottom = 1
    borders.bottom_colour = 0x3A
    style.borders = borders
    return style


# Write the confusion matrices and metrics to an Excel workbook, one sheet per model
def write_excel(info_dict, res_name="./res.xls", cls_dict=""):
    # Build the two cell styles once instead of once per cell
    bold_style = set_style('Times New Roman', 220, True)
    plain_style = set_style('Times New Roman', 220, False)
    f = xlwt.Workbook()
    for sheet_name, val_dict in info_dict.items():
        sheet = f.add_sheet(sheet_name, cell_overwrite_ok=True)

        # Confusion Matrix
        # row0 = ["Confusion Matrix", "class 0", "class 1", ..., "class N", "Pass Rate (%)"]
        # colum0 = ["class 0", "class 1", ..., "class N", "Model"]
        row0 = ["Confusion Matrix"]
        colum0 = []
        for _, cls_name in cls_dict.items():
            row0.append(cls_name)
            colum0.append(cls_name)
        row0.append("Pass Rate (%)")
        colum0.append("Model")
        # first row
        for i in range(0, len(row0)):
            sheet.write(0, i, row0[i], bold_style)
        # first col
        for i in range(0, len(colum0)):
            sheet.write(i+1, 0, colum0[i], bold_style)
        confusion_matrix = val_dict['confusion_matrix']
        for row in range(confusion_matrix.shape[0]):
            for col in range(confusion_matrix.shape[1]):
                sheet.write(row+1, col+1, int(confusion_matrix[row][col]), plain_style)

        # Pass rate of each class: share of each actual class predicted correctly
        # (-1 marks a class with no samples)
        for row in range(confusion_matrix.shape[0]):
            if sum(confusion_matrix[row]) == 0:
                ACC = -1
            else:
                ACC = round(confusion_matrix[row][row]/sum(confusion_matrix[row])*100, 2)
            sheet.write(row+1, confusion_matrix.shape[1]+1, ACC, plain_style)

        # Accuracy of the model
        num_correct = 0
        for cls_index in range(confusion_matrix.shape[0]):
            num_correct += confusion_matrix[cls_index][cls_index]
        ACC_model = round(num_correct/sum(sum(confusion_matrix))*100, 2)
        sheet.write(confusion_matrix.shape[0]+1, confusion_matrix.shape[1]+1, ACC_model, plain_style)
        for i in range(1, confusion_matrix.shape[1]+1):
            sheet.write(confusion_matrix.shape[0]+1, i, '', plain_style)

        # Index - Accuracy (ACC), Precision (PPV), Sensitivity (Recall, TPR), Specificity (TNR), F1-Score
        # row0 = ["Index (%)", "class 0", "class 1", ..., "class N"]
        # colum0 = ["Accuracy", "Precision", "Sensitivity (Recall)", "Specificity", "F1-Score"]
        sep = 2
        first_row = confusion_matrix.shape[0] + 2 + sep
        row0 = ["Index (%)"]
        for _, cls_name in cls_dict.items():
            row0.append(cls_name)
        colum0 = ["Accuracy", "Precision", "Sensitivity (Recall)", "Specificity", "F1-Score"]
        # first row
        for i in range(0, len(row0)):
            sheet.write(first_row, i, row0[i], bold_style)
        # first col
        for i in range(0, len(colum0)):
            sheet.write(i+1+first_row, 0, colum0[i], bold_style)
        index_list = val_dict['index']
        for row in range(len(index_list)):
            for col in range(len(index_list[row])):
                sheet.write(col+1+first_row, row+1, round(index_list[row][col]*100, 2), plain_style)

        # TP, FN, FP, TN of each class, written to the right of the confusion matrix
        sep = 1
        first_col = confusion_matrix.shape[1] + 2 + sep
        row0 = ["Positive", "Negative"]
        colum0 = ["Positive", "Negative"]
        for cls_index in range(len(cls_dict.keys())):
            # Each class block takes len(colum0) + 2 rows
            top = cls_index * (len(colum0) + 2)
            sheet.write(top, first_col, cls_dict[cls_index], bold_style)
            # first row
            for i in range(0, len(row0)):
                sheet.write(top, i+1+first_col, row0[i], plain_style)
            # first col
            for i in range(0, len(colum0)):
                sheet.write(top+i+1, first_col, colum0[i], plain_style)
            # value - TP, FN, FP, TN
            value = val_dict[cls_dict[cls_index]]
            sheet.write(top+1, first_col+1, int(value[0]), plain_style)
            sheet.write(top+1, first_col+2, int(value[1]), plain_style)
            sheet.write(top+2, first_col+1, int(value[2]), plain_style)
            sheet.write(top+2, first_col+2, int(value[3]), plain_style)
    f.save(res_name)


# Print the confusion matrix as a text table with a per-class pass rate column
def print_confusion_matrix(confusion_matrix, model_name='', cls_dict=''):
    num_cls = confusion_matrix.shape[0]
    border = '+'+'-------+'*(len(cls_dict.keys())+2)
    print('')
    print('------------- ', model_name, ' Confusion Matrix -------------')
    print('row: target, col: predicted')
    print(border)
    # header row
    print('|'+'\t', end='')
    for predict_index in range(num_cls):
        print('|'+cls_dict[predict_index].rjust(6), end='\t')
    print('|'+'Acc'.rjust(6), end=' |\n')
    print(border)
    # one row per actual class
    for target_index in range(num_cls):
        print('|'+cls_dict[target_index].rjust(6), end='\t')
        for predict_index in range(num_cls):
            print('|'+str(confusion_matrix[target_index][predict_index]).rjust(6), end='\t')
        print('|'+"{:.2f}%".format(confusion_matrix[target_index][target_index]/sum(confusion_matrix[target_index])*100).rjust(6), end=' |\n')
        print(border)


# Compute and print the overall accuracy of the model
def cal_accuracy_rate(confusion_matrix, model_name='', cls_dict=''):
    # Take the class count from the matrix instead of relying on a global variable
    num_cls = confusion_matrix.shape[0]
    num_total = np.sum(confusion_matrix)
    # accuracy rate
    num_accuracy = 0
    for i in range(num_cls):
        num_accuracy += confusion_matrix[i][i]
    accuracy_rate = num_accuracy / num_total
    print('')
    print('------------- ', model_name, ' Accuracy Rate -------------')
    print('Number of correct prediction is ', num_accuracy)
    print('Number of test data is ', num_total)
    print('The accuracy rate is {:.2f}'.format(accuracy_rate*100)+'%')
    print('-------------------------------------------------------')


# Return 0.0 instead of NaN when the denominator is zero
# (for example, a class that has no samples or is never predicted)
def _ratio(numerator, denominator):
    return numerator / denominator if denominator else 0.0


# Compute and print ACC, PPV, TPR, TNR and F1-Score for each class
def cal_other_index(confusion_matrix, res_dict=None, model_name='', cls_dict=''):
    # Avoid a mutable default argument shared between calls
    if res_dict is None:
        res_dict = {}
    res_dict['index'] = []
    index_dict = {}
    num_cls = confusion_matrix.shape[0]
    print('')
    print('------------- ', model_name, ' ACC, PPV, TPR, TNR, F1-Score -------------')
    for cls_index in range(num_cls):
        cls_name = cls_dict[cls_index]
        # TP, TN, FP, FN
        TP = confusion_matrix[cls_index][cls_index]
        TN, FP, FN = 0, 0, 0
        for target_index in range(num_cls):
            for predict_index in range(num_cls):
                if target_index == cls_index and predict_index != cls_index:
                    FN += confusion_matrix[target_index][predict_index]
                elif predict_index == cls_index and target_index != cls_index:
                    FP += confusion_matrix[target_index][predict_index]
                elif target_index != cls_index and predict_index != cls_index:
                    TN += confusion_matrix[target_index][predict_index]
        index_dict[cls_name] = {'TP': TP, 'TN': TN, 'FP': FP, 'FN': FN}
        # Accuracy
        ACC = _ratio(TP + TN, TP + TN + FP + FN)
        # Precision
        PPV = _ratio(TP, TP + FP)
        # Sensitivity (Recall)
        TPR = _ratio(TP, TP + FN)
        # Specificity
        TNR = _ratio(TN, TN + FP)
        # F1-Score
        F1_Score = _ratio(2 * PPV * TPR, PPV + TPR)
        # TP, FN, FP, TN
        res_dict[cls_name] = [TP, FN, FP, TN]
        res_dict['index'].append([ACC, PPV, TPR, TNR, F1_Score])
        # print results
        # TP, TN, FP, FN
        print('cls: '+cls_name, '\trow: target, col: predicted')
        print('+'+'-------+'*3)
        print('|'+'\t|', 'Pos'.rjust(6)+'|', 'Neg'.rjust(6)+'|')
        print('+'+'-------+'*3)
        print('|'+'Pos'.rjust(6)+' |'+str(TP).rjust(6)+' |'+str(FN).rjust(6)+' |')
        print('+'+'-------+'*3)
        print('|'+'Neg'.rjust(6)+' |'+str(FP).rjust(6)+' |'+str(TN).rjust(6)+' |')
        print('+'+'-------+'*3)
        print('Accuracy = {:.2f}'.format(ACC*100)+'%')
        print('Precision = {:.2f}'.format(PPV*100)+'%')
        print('Sensitivity (Recall) = {:.2f}'.format(TPR*100)+'%')
        print('Specificity = {:.2f}'.format(TNR*100)+'%')
        print('F1-Score = {:.2f}'.format(F1_Score*100)+'%')
        print('----------'*2)
    return res_dict


if __name__ == '__main__':
    root_path = './results'
    save_path = './results/images'
    version_suffix = 'v4_3cls'
    cls_suffix = "blog"
    # result file name -> model name (also used as the Excel sheet name)
    models_dict = {
        'resnet18_'+version_suffix+"_"+cls_suffix+'.txt': 'resnet18_'+version_suffix,
        'mobilenetv2_'+version_suffix+"_"+cls_suffix+'.txt': 'mobilenetv2_'+version_suffix,
    }
    xls_path = './results/excel/res_'+version_suffix+"_"+cls_suffix+'.xls'
    cls_dict = {0: 'cat', 1: 'dog', 2: 'pig'}
    num_cls = len(cls_dict.keys())
    res_dict = {}
    for res_name, model_name in models_dict.items():
        res_dict[model_name] = {}
        # initialize matrix
        confusion_matrix = np.zeros((num_cls, num_cls), dtype=int)
        res_path = os.path.join(root_path, res_name)
        with open(res_path) as f_src:
            lines = f_src.readlines()
        # each line: "<image path> <target index> <predicted index>"
        for line in lines:
            if not line.strip():
                continue  # skip blank lines
            line_split = line.split('\n')[0].split(' ')
            img_name = line_split[0].split('/')[-1]
            target = int(line_split[1])
            predict = int(line_split[2])
            confusion_matrix[target][predict] += 1
            # # save misclassified images
            # if target != predict:
            #     new_path = os.path.join(save_path, model_name)
            #     new_path = os.path.join(new_path, cls_suffix)
            #     new_path = os.path.join(new_path, cls_dict[predict])
            #     new_path = os.path.join(new_path, img_name)
            #     shutil.copyfile(line_split[0], new_path)
        res_dict[model_name]['confusion_matrix'] = confusion_matrix
        print_confusion_matrix(confusion_matrix, model_name, cls_dict)
        cal_accuracy_rate(confusion_matrix, model_name, cls_dict)
        cal_other_index(confusion_matrix, res_dict[model_name], model_name, cls_dict)
    # make sure the output directory exists before saving the workbook
    os.makedirs(os.path.dirname(xls_path), exist_ok=True)
    write_excel(res_dict, xls_path, cls_dict)
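Judging from the parsing loop, each result file holds one sample per line: the image path, the target class index, and the predicted class index, separated by single spaces. The sketch below writes a tiny file in that inferred format and parses it back; the file name `demo_results.txt` and the image paths are made up for illustration and do not come from the original project.

```python
import os
import tempfile

import numpy as np

# Hypothetical sample lines: "<image path> <target index> <predicted index>"
sample_lines = [
    "data/test/cat/0001.jpg 0 0",
    "data/test/cat/0002.jpg 0 2",
    "data/test/dog/0001.jpg 1 1",
    "data/test/pig/0001.jpg 2 2",
]

with tempfile.TemporaryDirectory() as tmp_dir:
    res_path = os.path.join(tmp_dir, "demo_results.txt")
    with open(res_path, "w") as f_dst:
        f_dst.write("\n".join(sample_lines) + "\n")

    # Parse the file the same way the main script does
    matrix = np.zeros((3, 3), dtype=int)
    with open(res_path) as f_src:
        for line in f_src:
            path, target, predict = line.strip().split(" ")
            matrix[int(target)][int(predict)] += 1

print(matrix)
```

Note that a path containing spaces would break this simple split, so such files would need a different separator.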
2.2 Example
2.2.1 Printed Output
2.2.2 Excel Output
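3 Summary
In classification tasks, the most commonly used evaluation tools are the confusion matrix, accuracy, precision, and recall. Which metric to prioritize, and which method to use to improve it, depends on the task at hand.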