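[Deep Learning] [Machine Learning] Metrics and Methods for Analyzing Classification Results (Confusion Matrix, TP, TN, FP, FN, Precision, Recall) (Source Code Included)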

 

Contents

0 Preface
1 Metrics
1.1 TP, FP, FN, TN
1.2 Confusion matrix
1.2.1 Binary classification
1.2.2 Multi-class classification
1.3 Second-level metrics
1.3.1 Accuracy
1.3.2 Precision
1.3.3 Sensitivity / Recall
1.3.4 Specificity
1.4 Third-level metrics
1.4.1 F-measure
1.4.2 F1-measure
2 Code
2.1 Source code
2.2 Example
2.2.1 Printed output
2.2.2 Excel output
3 Summary


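0 Preface

A while ago I worked on a classification project that used ResNet18 and MobileNetV2 to classify the data; the former mainly targets GPU, the latter CPU. I evaluated the models with a confusion matrix together with accuracy, recall and F-score. This post walks through each of these metrics in detail.

Video walkthrough (on bilibili): [Deep Learning] [Machine Learning] [Code Walkthrough] Metrics and Methods for Analyzing Classification Results (Confusion Matrix, TP, TN, FP, FN, Precision, Recall) (Source Code Included)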

 

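1 Metrics

1.1 TP, FP, FN, TN

First, what the letters stand for: T is True, P is Positive, F is False, N is Negative. Therefore:

TP: the actual value is Positive and the predicted value is Positive, i.e. a true positive;

FP: the actual value is Negative and the predicted value is Positive, i.e. a false positive;

FN: the actual value is Positive and the predicted value is Negative, i.e. a false negative;

TN: the actual value is Negative and the predicted value is Negative, i.e. a true negative.

In statistics, an FP is called a Type I error and an FN is called a Type II error.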
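To make the four definitions concrete, here is a minimal sketch that counts them from two label lists. The function name `count_outcomes` and the toy labels are my own, made up for illustration only.

```python
def count_outcomes(actual, predicted, positive=1):
    """Count TP, FP, FN, TN, treating `positive` as the Positive label."""
    tp = fp = fn = tn = 0
    for a, p in zip(actual, predicted):
        if p == positive and a == positive:
            tp += 1          # predicted Positive, actually Positive
        elif p == positive:
            fp += 1          # predicted Positive, actually Negative (Type I error)
        elif a == positive:
            fn += 1          # predicted Negative, actually Positive (Type II error)
        else:
            tn += 1          # predicted Negative, actually Negative
    return tp, fp, fn, tn


# Hypothetical toy labels: 1 = Positive, 0 = Negative
actual = [1, 1, 1, 0, 0, 0, 1, 0]
predicted = [1, 0, 1, 0, 1, 0, 1, 0]

tp, fp, fn, tn = count_outcomes(actual, predicted)
print(tp, fp, fn, tn)  # 3 1 1 3
```

The four counts always add up to the number of samples, which is a quick sanity check on any implementation.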

 

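1.2 Confusion matrix

If we lay out TP, FP, FN and TN in a table, that table is the confusion matrix.

1.2.1 Binary classification

In binary classification, each sample has only two possible predictions: positive or negative.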

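Table 1  Confusion matrix for binary classification

| Confusion matrix | Predicted: Positive | Predicted: Negative |
|---|---|---|
| Actual: Positive | TP | FN |
| Actual: Negative | FP | TN |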

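For example, suppose we want to classify pictures of cats and dogs. We can treat cat as the positive class and dog as the negative class, so there are only two possible predictions: cat, or not cat (dog). Suppose there are 50 cat pictures and 53 dog pictures, of which 45 and 47 respectively are classified correctly. The remaining 5 cats are predicted as dogs (FN) and the remaining 6 dogs are predicted as cats (FP), which gives the confusion matrix in Table 2.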

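Table 2  Confusion matrix for the cat/dog binary classification

| Confusion matrix | Predicted: cat | Predicted: dog |
|---|---|---|
| Actual: cat | 45 | 5 |
| Actual: dog | 6 | 47 |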
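As a rough sketch of how such a matrix is filled in, the snippet below rebuilds the cat/dog counts from individual (actual, predicted) pairs, using the same `matrix[target][predict] += 1` update as the source code later in this post. The class indices `CAT` and `DOG` are hypothetical names chosen for the example.

```python
import numpy as np

CAT, DOG = 0, 1  # hypothetical class indices for this example

# Recreate the (actual, predicted) pairs behind the cat/dog example:
# 45 cats predicted as cat, 5 as dog; 6 dogs predicted as cat, 47 as dog
pairs = ([(CAT, CAT)] * 45 + [(CAT, DOG)] * 5 +
         [(DOG, CAT)] * 6 + [(DOG, DOG)] * 47)

matrix = np.zeros((2, 2), dtype=int)
for actual, predicted in pairs:
    matrix[actual][predicted] += 1  # row: actual, column: predicted

print(matrix)
# [[45  5]
#  [ 6 47]]
```

Each row sums to the number of pictures of that class (50 cats, 53 dogs), so the row sums are an easy check that no sample was lost.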

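1.2.2 Multi-class classification

Many classification tasks are not simply binary; they may have three or more classes, and the confusion matrix then looks slightly different. The confusion matrix for N classes is shown in Table 3.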

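Table 3  Confusion matrix for N-class classification

| Confusion matrix | Predicted: class 1 | Predicted: class 2 | ······ | Predicted: class N |
|---|---|---|---|---|
| Actual: class 1 | | | | |
| Actual: class 2 | | | | |
| ······ | | | | |
| Actual: class N | | | | |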

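For example, suppose we want to classify pictures of cats, dogs and pigs, with 51, 52 and 49 pictures respectively. Of the cat pictures, 47, 1 and 3 are predicted as cat, dog and pig; of the dog pictures, 1, 49 and 2 are predicted as cat, dog and pig; of the pig pictures, 1, 0 and 48 are predicted as cat, dog and pig. The confusion matrix is then as shown in Table 4.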

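Table 4  Confusion matrix for the cat/dog/pig three-class classification

| Confusion matrix | Predicted: cat | Predicted: dog | Predicted: pig |
|---|---|---|---|
| Actual: cat | 47 | 1 | 3 |
| Actual: dog | 1 | 49 | 2 |
| Actual: pig | 1 | 0 | 48 |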

 

The three cats that were mistaken for pigs are, nine times out of ten, orange tabbies.

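For each class, we can also treat the problem as binary classification, i.e. the prediction is either this class or not this class, and draw a separate confusion matrix for it. For the cat class, for instance, predictions of dog and pig can be lumped together as "not cat", which gives the confusion matrix for cat shown in Table 5.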

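Table 5  Confusion matrix for cat in the cat/dog/pig three-class classification

| Confusion matrix | Predicted: cat | Predicted: not cat |
|---|---|---|
| Actual: cat | 47 | 4 |
| Actual: not cat | 2 | 99 |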
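The collapse from a three-class matrix to a per-class binary one can be sketched with a few NumPy sums. The helper name `one_vs_rest` is my own; the numbers are the cat/dog/pig counts from the example above.

```python
import numpy as np

# Three-class matrix: rows are actual cat/dog/pig, columns are predicted cat/dog/pig
matrix = np.array([[47, 1, 3],
                   [1, 49, 2],
                   [1, 0, 48]])


def one_vs_rest(matrix, k):
    """Collapse an N-class matrix into [[TP, FN], [FP, TN]] for class k."""
    tp = matrix[k, k]
    fn = matrix[k, :].sum() - tp      # actual k, predicted as another class
    fp = matrix[:, k].sum() - tp      # actual another class, predicted as k
    tn = matrix.sum() - tp - fn - fp  # everything that involves neither
    return np.array([[tp, fn], [fp, tn]])


cat = one_vs_rest(matrix, 0)
print(cat)
# [[47  4]
#  [ 2 99]]
```

Calling `one_vs_rest(matrix, 1)` or `one_vs_rest(matrix, 2)` gives the corresponding matrices for dog and pig.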

 

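1.3 Second-level metrics

The second-level metrics are accuracy, precision, recall and specificity.

1.3.1 Accuracy

Accuracy: the proportion of correctly classified results among all observations. Accuracy describes the model as a whole and is computed as follows: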

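(1) Multi-class model

ACC = number of correctly classified results / total number of observations

For Table 4, the accuracy is $ACC = \tfrac{47+49+48}{51+52+49} \approx 94.74\%$

In a multi-class model, the accuracy for a single class k is $ACC_{class\: k} = \tfrac{TP_{class\: k}+TN_{class\: k}}{TP_{class\: k}+TN_{class\: k}+FP_{class\: k}+FN_{class\: k}}$

For Table 5, the accuracy for cat in the three-class problem is $ACC_{cat} = \tfrac{47+99}{47+99+2+4} \approx 96.05\%$

(2) Binary model

$ACC = \tfrac{TP+TN}{TP+TN+FP+FN}$

For Table 2, the accuracy is $ACC = \tfrac{45+47}{45+47+6+5} \approx 89.32\%$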
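The three accuracy figures above can be reproduced with a short sketch. The helper names `accuracy` and `class_accuracy` are hypothetical; the matrices are the cat/dog and cat/dog/pig examples from section 1.2.

```python
import numpy as np

table4 = np.array([[47, 1, 3], [1, 49, 2], [1, 0, 48]])  # cat/dog/pig
table2 = np.array([[45, 5], [6, 47]])                    # cat/dog


def accuracy(matrix):
    """Correct predictions (the diagonal) divided by all samples."""
    return np.trace(matrix) / matrix.sum()


def class_accuracy(matrix, k):
    """(TP + TN) / total, with class k treated as one-vs-rest."""
    tp = matrix[k, k]
    fn = matrix[k, :].sum() - tp
    fp = matrix[:, k].sum() - tp
    tn = matrix.sum() - tp - fn - fp
    return (tp + tn) / matrix.sum()


print(f"{accuracy(table4):.2%}")           # 94.74% (whole three-class model)
print(f"{class_accuracy(table4, 0):.2%}")  # 96.05% (cat only)
print(f"{accuracy(table2):.2%}")           # 89.32% (binary model)
```

Note that the per-class accuracy is higher than the overall accuracy here because the large TN count (99) dominates; this is one reason accuracy alone can look flattering for a single class.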

 

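1.3.2 Precision

Precision: among all results the model predicts as Positive, the proportion it gets right:

$PPV = \tfrac{TP}{TP+FP}$

In Table 2, the precision for cat is $PPV_{cat} = \tfrac{45}{45+6} \approx 88.24\%$; in Table 5, the precision for cat in the three-class problem is $PPV_{cat} = \tfrac{47}{47+2} \approx 95.92\%$

Precision is about the predictions. Put simply, it matters most when a "wrongful conviction" (false positive) is costly and "a fish slipping through the net" (false negative) is cheap. For example, suppose we classify email as spam (Positive) or not spam (Negative). If a spam message is mistaken for a normal one, we may waste only a few seconds opening it; but if an important message is thrown into the spam folder, we may miss critical information. In this case FP should be as small as possible, which, for a fixed TP, means PPV should be as large as possible.

In information retrieval, this metric is likewise called precision (查准率 in Chinese).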
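For a multi-class matrix, precision for every class can be read off in one step, since each column sum equals TP + FP for that class. A minimal sketch using the cat/dog/pig counts from section 1.2.2 (variable names are my own):

```python
import numpy as np

# Rows: actual cat/dog/pig, columns: predicted cat/dog/pig
table4 = np.array([[47, 1, 3], [1, 49, 2], [1, 0, 48]])

# Column sums count everything predicted as each class, i.e. TP + FP
precision = np.diag(table4) / table4.sum(axis=0)

for name, ppv in zip(["cat", "dog", "pig"], precision):
    print(f"{name}: {ppv:.2%}")
# cat: 95.92%
# dog: 98.00%
# pig: 90.57%
```

This sketch assumes every class is predicted at least once; a column of zeros would make the division undefined and needs separate handling.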

 

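1.3.3 Sensitivity / Recall

Recall: among all results whose actual value is Positive, the proportion the model gets right:

$TPR = \tfrac{TP}{TP+FN}$

In Table 2, the recall for cat is $TPR_{cat} = \tfrac{45}{45+5} = 90.00\%$; in Table 5, the recall for cat in the three-class problem is $TPR_{cat} = \tfrac{47}{47+4} \approx 92.16\%$

Recall is about the samples (the actual values). It asks the classifier to be broad and complete, emphasizing quantity. Put simply, it matters most when a "wrongful conviction" is cheap and "a fish slipping through the net" is costly. For example, suppose we predict whether an earthquake will occur at a given time: Positive if so, Negative otherwise. If the system predicts an earthquake today and issues an early warning that turns out to be a false alarm, people only lose some time evacuating; but if an earthquake really happens and was not predicted, lives and property suffer severe losses. In this case FN should be as small as possible, so TPR should be as large as possible.

In information retrieval, this metric is likewise called recall (查全率 in Chinese).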
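Recall for every class of a multi-class matrix follows the same pattern as precision, except that it divides by row sums, since each row sum equals TP + FN for that class. A minimal sketch using the cat/dog/pig counts from section 1.2.2 (variable names are my own):

```python
import numpy as np

# Rows: actual cat/dog/pig, columns: predicted cat/dog/pig
table4 = np.array([[47, 1, 3], [1, 49, 2], [1, 0, 48]])

# Row sums count all samples that truly belong to each class, i.e. TP + FN
recall = np.diag(table4) / table4.sum(axis=1)

for name, tpr in zip(["cat", "dog", "pig"], recall):
    print(f"{name}: {tpr:.2%}")
# cat: 92.16%
# dog: 94.23%
# pig: 97.96%
```

Pig has the highest recall (only one pig picture is missed) even though other classes leak into its column, which shows how recall and precision can rank the classes differently.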

 

1.3.4 Specificity

Specificity: among all results whose actual value is Negative, the proportion the model gets right:

$TNR = \tfrac{TN}{TN+FP}$
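The post gives no worked number for specificity, so here is a small sketch that applies the formula to the TN and FP counts from the cat/dog example (TN = 47, FP = 6) and from the one-vs-rest cat matrix (TN = 99, FP = 2). The function name and the zero-denominator fallback are my own choices.

```python
def specificity(tn, fp):
    """TNR = TN / (TN + FP); returns 0.0 when there are no actual negatives."""
    total_negative = tn + fp
    return tn / total_negative if total_negative else 0.0


print(f"{specificity(tn=47, fp=6):.2%}")  # 88.68% (cat vs dog)
print(f"{specificity(tn=99, fp=2):.2%}")  # 98.02% (cat vs not cat, three classes)
```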

 

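1.4 Third-level metrics

1.4.1 F-measure

F-measure is the weighted harmonic mean of Precision and Recall:

$F = \tfrac{\left(\alpha^{2}+1\right)\cdot Precision\cdot Recall}{\alpha^{2}\cdot Precision+Recall}$

With α > 1 recall carries more weight; with α < 1 precision carries more weight.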
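Here is a minimal sketch of the weighted harmonic mean, written in the standard form whose denominator is α²·Precision + Recall. The function name `f_measure` and the zero-denominator fallback are my own; the precision and recall values are those of cat in the cat/dog example.

```python
def f_measure(precision, recall, alpha=1.0):
    """Weighted harmonic mean of precision and recall.

    alpha > 1 weights recall more heavily, alpha < 1 weights precision more.
    """
    a2 = alpha ** 2
    denom = a2 * precision + recall
    if denom == 0:
        return 0.0  # both precision and recall are zero
    return (a2 + 1) * precision * recall / denom


# Cat in the cat/dog example: precision = 45/51, recall = 45/50
p, r = 45 / 51, 45 / 50
for alpha in (0.5, 1.0, 2.0):
    print(alpha, round(f_measure(p, r, alpha), 4))
```

Since recall (0.9) is higher than precision (about 0.8824) here, the score rises as α grows and recall gets more weight.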

 

1.4.2 F1-measure

When α = 1, $F1\: Score = \tfrac{2\cdot Precision\cdot Recall}{Precision+Recall}$

The F1 score lies in the range [0, 1]. The closer it is to 0, the worse the model's output; the closer it is to 1, the better.
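As a worked example, F1 can also be computed straight from the counts: substituting PPV = TP/(TP+FP) and TPR = TP/(TP+FN) into the formula gives F1 = 2·TP / (2·TP + FP + FN). The sketch below applies this to cat in the binary example and in the three-class example; the function name is hypothetical.

```python
def f1_from_counts(tp, fp, fn):
    """F1 = 2*TP / (2*TP + FP + FN), equivalent to 2*P*R / (P + R)."""
    denom = 2 * tp + fp + fn
    return 2 * tp / denom if denom else 0.0


f1_table2 = f1_from_counts(tp=45, fp=6, fn=5)  # cat vs dog
f1_table5 = f1_from_counts(tp=47, fp=2, fn=4)  # cat vs not cat, three classes

print(f"{f1_table2:.2%}")  # 89.11%
print(f"{f1_table5:.2%}")  # 94.00%
```

The count form avoids computing precision and recall separately, so it only fails when TP, FP and FN are all zero.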

 

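2 Code

The code analyzes the three-class results of two models, ResNet18 and MobileNetV2. It builds and prints the confusion matrix, computes and prints the second- and third-level metrics, and finally saves the confusion matrix and the metrics to an Excel file for later analysis.

2.1 Source code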

import os
import numpy as np
import xlwt
import shutil

# Set the cell style (font and borders)
def set_style(name, height, bold=False):
    style = xlwt.XFStyle()
    font = xlwt.Font()
    font.name = name
    font.bold = bold
    font.color_index = 4
    font.height = height
    style.font = font
    
    borders = xlwt.Borders()
    borders.left = 1
    borders.right = 1
    borders.top = 1
    borders.bottom = 1
    borders.bottom_colour=0x3A
    style.borders = borders

    return style


# Write the results to an Excel (.xls) workbook, one sheet per model
def write_excel(info_dict, res_name="./res.xls", cls_dict=""):
    f = xlwt.Workbook()
    for sheet_name, val_dict in info_dict.items():
        sheet = f.add_sheet(sheet_name, cell_overwrite_ok=True)

        # Confusion Matrix
        # row0 = ["Confusion Matrix", "class 0", "class 1", ..., "class N", "Pass Rate (%)"]
        # colum0 = ["class 0", "class 1", ..., "class N", "model"]
        row0 = ["Confusion Matrix"]
        colum0 = []
        for _, cls_name in cls_dict.items():
            row0.append(cls_name)
            colum0.append(cls_name)
        row0.append("Pass Rate (%)")
        colum0.append("Model")
        
        # first row
        for i in range(0, len(row0)):
            sheet.write(0, i, row0[i], set_style('Times New Roman',220,True))
        # first col
        for i in range(0, len(colum0)):
            sheet.write(i+1, 0, colum0[i], set_style('Times New Roman',220,True))
        confusion_matrix = val_dict['confusion_matrix']
        for row in range(confusion_matrix.shape[0]):
            for col in range(confusion_matrix.shape[1]):
                sheet.write(row+1, col+1, int(confusion_matrix[row][col]), set_style('Times New Roman',220,False))
        # Accuracy of each class
        for row in range(confusion_matrix.shape[0]):
            if sum(confusion_matrix[row])*100 == 0:
                ACC = -1
            else:
                ACC = round(confusion_matrix[row][row]/sum(confusion_matrix[row])*100, 2)
            sheet.write(row+1, confusion_matrix.shape[1]+1, ACC, set_style('Times New Roman',220,False))
        # Accuracy of the model
        num_correct = 0
        for cls_index in range(confusion_matrix.shape[0]):
            num_correct += confusion_matrix[cls_index][cls_index]
        ACC_model = round(num_correct/sum(sum(confusion_matrix))*100, 2)
        sheet.write(confusion_matrix.shape[0]+1, confusion_matrix.shape[1]+1, ACC_model, set_style('Times New Roman',220,False))
        for i in range(1, confusion_matrix.shape[1]+1):
            sheet.write(confusion_matrix.shape[0]+1, i, '', set_style('Times New Roman',220,False))

        sep = 2
        # Index - Accuracy (ACC), Precision (PPV), Sensitivity (Recall, TPR), Specificity (TNR), F1-Score
        # first row
        first_row = confusion_matrix.shape[0] + 2 + sep
        # row0 = ["Index (%)", "class 0", "class 1", ..., "class N"]
        # colum0 = ["Accuracy", "Precision", "Sensitivity (Recall)", "Specificity", "F1-Score"]
        row0 = ["Index (%)"]
        for _, cls_name in cls_dict.items():
            row0.append(cls_name)
        colum0 = ["Accuracy", "Precision", "Sensitivity (Recall)", "Specificity", "F1-Score"]

        for i in range(0, len(row0)):
            sheet.write(first_row, i, row0[i], set_style('Times New Roman',220,True))
        # first col
        for i in range(0, len(colum0)):
            sheet.write(i+1+first_row, 0, colum0[i], set_style('Times New Roman',220,True))
        index_list = val_dict['index']
        for row in range(len(index_list)):
            for col in range(len(index_list[row])):
                sheet.write(col+1+first_row, row+1, round(index_list[row][col]*100, 2), set_style('Times New Roman',220,False))

        sep = 1
        # TP, TN, FP, FN of each class
        # first row
        first_col = confusion_matrix.shape[1] + 2 + sep
        row0 = ["Positive", "Negative"]
        colum0 = ["Positive", "Negative"]
        for cls_index in range(len(cls_dict.keys())):
            sheet.write(cls_index*len(colum0)+cls_index*2, first_col, cls_dict[cls_index], set_style('Times New Roman',220,True))
            for i in range(0, len(row0)):
                sheet.write(cls_index*len(colum0)+cls_index*2, i+1+first_col, row0[i], set_style('Times New Roman',220,False))
            # first col
            for i in range(0, len(colum0)):
                sheet.write(cls_index*len(colum0)+cls_index*2+i+1, first_col, colum0[i], set_style('Times New Roman',220,False))
            # value - TP, FN, FP, TN
            value = val_dict[cls_dict[cls_index]]
            sheet.write(cls_index*len(colum0)+cls_index*2+1, first_col+1, int(value[0]), set_style('Times New Roman',220,False))
            sheet.write(cls_index*len(colum0)+cls_index*2+1, first_col+2, int(value[1]), set_style('Times New Roman',220,False))
            sheet.write(cls_index*len(colum0)+cls_index*2+2, first_col+1, int(value[2]), set_style('Times New Roman',220,False))
            sheet.write(cls_index*len(colum0)+cls_index*2+2, first_col+2, int(value[3]), set_style('Times New Roman',220,False))
        
        # sheet.write(1,3,'2006/12/12')
        # sheet.write_merge(6,6,1,3,'unknown')  # merge cells within one row
        # sheet.write_merge(1,2,3,3,'gaming')  # merge cells within one column
        # sheet.write_merge(4,5,3,3,'basketball')

        # Confusion Matrix for each class

    f.save(res_name)



def print_confusion_matrix(confusion_matrix, model_name='', cls_dict=''):
    num_cls = confusion_matrix.shape[0]
    print('')
    print('------------- ', model_name, ' Confusion Matrix -------------')
    print('row: target, col: predicted')
    # print('+'+'-'*47+'+')
    print('+'+'-------+'*(len(cls_dict.keys())+2))
    print('|'+'\t', end='')
    for predict_index in range(num_cls):
        # if predict_index == num_cls - 1:
        #     end_str = '\t' + '|' + '\n'
        # else:
        #     end_str = '\t'
        end_str = '\t'
        print('|'+cls_dict[predict_index].rjust(6), end=end_str)
    print('|'+'Acc'.rjust(6), end=' |\n')
    print('+'+'-------+'*(len(cls_dict.keys())+2))
    for target_index in range(num_cls):
        print('|'+cls_dict[target_index].rjust(6), end='\t')
        for predict_index in range(num_cls):
            # if predict_index == num_cls - 1:
            #     end_str = '\t' + '|'+'\n'
            # else:
            #     end_str = '\t'
            end_str = '\t'
            print('|'+str(confusion_matrix[target_index][predict_index]).rjust(6), end=end_str)
        
        print('|'+ "{:.2f}%".format(confusion_matrix[target_index][target_index]/sum(confusion_matrix[target_index])*100).rjust(6), end=' |\n')
        print('+'+'-------+'*(len(cls_dict.keys())+2))
    # print('+'+'-'*39+'+')


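def cal_accuracy_rate(confusion_matrix, model_name='', cls_dict=''):
    # Derive the class count locally instead of relying on the global set in __main__
    num_cls = confusion_matrix.shape[0]
    num_total = np.sum(confusion_matrix)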

    # accuracy rate
    num_accuracy = 0
    for i in range(num_cls):
        num_accuracy += confusion_matrix[i][i]
    accuracy_rate = num_accuracy / num_total
    print('')
    print('------------- ', model_name, ' Accuracy Rate -------------')
    print('Number of correct prediction is ', num_accuracy)
    print('Number of test data is ', num_total)
    print('The accuracy rate is {:.2f}'.format(accuracy_rate*100)+'%')
    print('-------------------------------------------------------')


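def cal_other_index(confusion_matrix, res_dict=None, model_name='', cls_dict=''):
    # Avoid a mutable default argument that would be shared between calls
    if res_dict is None:
        res_dict = {}
    res_dict['index'] = []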
    index_dict = {}
    num_cls = confusion_matrix.shape[0]
    print('')
    print('------------- ', model_name, ' ACC, PPV, TPR, TNR, F1-Score -------------')
    for cls_index in range(num_cls):
        index_dict[cls_dict[cls_index]] = {}
        # TP, TN, FP, FN
        TP = confusion_matrix[cls_index][cls_index]
        TN, FP, FN = 0, 0, 0
        for target_index in range(num_cls):
            for predict_index in range(num_cls):
                if target_index == cls_index and predict_index != cls_index:
                    FN += confusion_matrix[target_index][predict_index]
                elif predict_index == cls_index and target_index != cls_index:
                    FP += confusion_matrix[target_index][predict_index]
                elif target_index != cls_index and predict_index != cls_index:
                    TN += confusion_matrix[target_index][predict_index]
        index_dict[cls_dict[cls_index]]['TP'] = TP
        index_dict[cls_dict[cls_index]]['TN'] = TN
        index_dict[cls_dict[cls_index]]['FP'] = FP
        index_dict[cls_dict[cls_index]]['FN'] = FN
        # Accuracy
        ACC = (TP + TN) / (TP + TN + FP + FN)
        # Precision
        PPV = TP / (TP + FP)
        # Sensitivity (Recall)
        TPR = TP / (TP + FN)
        # Specificity
        TNR = TN / (TN + FP)
        # F1-Score
        F1_Score = 2 * PPV * TPR / (PPV + TPR)

        # TP, FN, FP, TN
        # res_dict[cls_dict[cls_index]] = []
        res_dict[cls_dict[cls_index]] = [TP, FN, FP, TN]

        res_dict['index'].append([ACC, PPV, TPR, TNR, F1_Score])

        # print results
        # TP, TN, FP, FN
        print('cls: '+cls_dict[cls_index], '\trow: target, col: predicted')
        print('+'+'-------+'*3)
        print('|'+'\t|', 'Pos'.rjust(6)+'|', 'Neg'.rjust(6)+'|')
        print('+'+'-------+'*3)
        print('|'+'Pos'.rjust(6)+' |'+str(index_dict[cls_dict[cls_index]]['TP']).rjust(6)+' |'+str(index_dict[cls_dict[cls_index]]['FN']).rjust(6)+' |')
        print('+'+'-------+'*3)
        print('|'+'Neg'.rjust(6)+' |'+str(index_dict[cls_dict[cls_index]]['FP']).rjust(6)+' |'+str(index_dict[cls_dict[cls_index]]['TN']).rjust(6)+' |')
        print('+'+'-------+'*3)

        print('Accuracy = {:.2f}'.format(ACC*100)+'%')
        print('Precision = {:.2f}'.format(PPV*100)+'%')
        print('Sensitivity (Recall) = {:.2f}'.format(TPR*100)+'%')
        print('Specificity = {:.2f}'.format(TNR*100)+'%')
        print('F1-Score = {:.2f}'.format(F1_Score*100)+'%')
        print('----------'*2)
        # print('Finished - ')
        

if __name__ == '__main__':

    root_path = './results'
    save_path = './results/images'

    version_suffix = 'v4_3cls'
    cls_suffix = "blog"

    models_dict = {'resnet18_'+version_suffix+"_"+cls_suffix+'.txt':'resnet18_'+version_suffix, 'mobilenetv2_'+version_suffix+"_"+cls_suffix+'.txt':'mobilenetv2_'+version_suffix}
    xls_path = './results/excel/res_'+version_suffix+"_"+cls_suffix+'.xls'

    cls_dict = {0:'cat', 1:'dog', 2:'pig'}

    num_cls = len(cls_dict.keys())
    res_dict = {}
    
    for res_name, model_name in models_dict.items():
        res_dict[model_name] = {}
        # initialize matrix
        confusion_matrix = np.zeros((num_cls, num_cls), dtype = int)
        res_path = os.path.join(root_path, res_name)
        with open(res_path) as f_src:
            lines = f_src.readlines()

            scores_dict = {}
            for _, cls_name in cls_dict.items():
                scores_dict[cls_name] = []

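            # Each line is expected to be "<image path> <target index> <predicted index>",
            # separated by single spaces
            for line in lines: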
                line_split = line.split('\n')[0].split(' ')
                img_name = line_split[0].split('/')[-1]
                target = int(line_split[1])
                predict = int(line_split[2])

                confusion_matrix[target][predict] += 1

                # # save image classified error
                # if target != predict:
                #     new_name = line_split[0].split("/")[-1]
                #     new_path = os.path.join(save_path, model_name)
                #     new_path = os.path.join(new_path, cls_suffix)
                #     new_path = os.path.join(new_path, cls_dict[predict])
                #     new_path = os.path.join(new_path, new_name)
                #     shutil.copyfile(line_split[0], new_path)

        res_dict[model_name]['confusion_matrix'] = confusion_matrix

        print_confusion_matrix(confusion_matrix, model_name, cls_dict)
        cal_accuracy_rate(confusion_matrix, model_name, cls_dict)
        cal_other_index(confusion_matrix, res_dict[model_name], model_name, cls_dict)

    write_excel(res_dict, xls_path, cls_dict)
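The script reads one results file per model. Judging from the parsing in the main block, each line holds an image path, the target class index and the predicted class index, separated by single spaces. The sketch below writes a tiny file in that layout and parses it back the same way; the file name, the image paths and the blank-line check are made up for illustration and are not part of the original script.

```python
import os
import tempfile

import numpy as np

cls_dict = {0: 'cat', 1: 'dog', 2: 'pig'}

# Hypothetical sample lines: "<image path> <target index> <predicted index>"
sample = "\n".join([
    "data/cat/001.jpg 0 0",
    "data/cat/002.jpg 0 2",
    "data/dog/001.jpg 1 1",
    "data/pig/001.jpg 2 2",
]) + "\n"

confusion_matrix = np.zeros((len(cls_dict), len(cls_dict)), dtype=int)

with tempfile.TemporaryDirectory() as tmp_dir:
    res_path = os.path.join(tmp_dir, "demo_results.txt")  # made-up file name
    with open(res_path, "w") as f_dst:
        f_dst.write(sample)

    with open(res_path) as f_src:
        for line in f_src:
            fields = line.strip().split(' ')
            if len(fields) < 3:
                continue  # skip blank or malformed lines
            target, predict = int(fields[1]), int(fields[2])
            confusion_matrix[target][predict] += 1  # row: target, col: predicted

print(confusion_matrix)
# [[1 0 1]
#  [0 1 0]
#  [0 0 1]]
```

A results file in this format for each model is all the script needs; the class indices must match the keys of `cls_dict`.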

2.2 Example

2.2.1 Printed output

2.2.2 Excel output

 

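3 Summary

In classification tasks, the most commonly used analysis tools are the confusion matrix, accuracy, precision and recall. Which metric to improve, and how, depends on the task at hand: precision when false positives are costly, recall when false negatives are costly.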

 

 

 

 

 

 
