使用sklearn工具包对分类结果进行客观评价(准确率,敏感率,召回率,F1-score, 以及混淆矩阵)

图像分类是CV领域内比较常见的任务,我们评价分类模型时,除了最基本的准确率之外,还会用到各种各样的指标,下面就挑出最具有代表性的几个常用评价指标进行介绍,并且使用sklearn工具包进行实现。

一、图像分类常见评价指标

1.1 混淆矩阵

如下图所示,要了解各个评价指标,首先需要知道混淆矩阵,混淆矩阵中的P表示Positive,即正例或者阳性,N表示Negative,即负例或者阴性。表中FP表示实际为负但被预测为正的样本数量,TN表示实际为负被预测为负的样本的数量,TP表示实际为正被预测为正的样本数量,FN表示实际为正但被预测为负的样本的数量。另外,TP+FP=P’表示所有被预测为正的样本数量,同理FN+TN为所有被预测为负的样本数量,TP+FN为实际为正的样本数量,FP+TN为实际为负的样本数量。

 

1.2 准确率

其实,只要我们得到混淆矩阵这张图,所有的评价指标值都可以根据混淆矩阵对应位置上的值进行计算了。

准确率是分类正确的样本占总样本个数的比例,即混淆矩阵中对角线的值相加除以所有的值

 

ACC=\frac{TP+TN}{TP+TN+FP+FN}

1.3  精确率

在分类任务中,精确率是指每一个类别的精确率。评价模型整体的精确率时,通常都是先计算每个类别的精确率,然后再加权平均得到的。精确率指模型预测为正的样本中实际也为正的样本占被预测为正的样本的比例,精确率越高,对应的误检概率也就越低,也就是更关心正样本的分类准确率。计算公式如下所示,即混淆矩阵每一行对角线的值与该行所有值相加的比值。

PRE=\frac{TP}{TP+FP}

1.4 召回率(敏感率)

在分类任务中,召回率是指每一个类别的召回率。召回率指实际为正的样本中被预测为正的样本所占实际为正的样本的比例,也就是说召回率越高,对应的漏检的概率也就越低,更关心的是负样本被检查出来的概率。具体计算公式如下,即混淆矩阵每一列对角线的值与该列所有值相加的比值:

SEN=\frac{TP}{TP+FN}

1.5  F1-Score

F1 score是精确率和召回率的调和平均值,计算公式为:

F1=\frac{2*PRE*SEN}{PRE+SEN}

二、 代码实现

直接使用sklearn工具包就可以了。主要就是先人为设定两个列表用来存储标签值和预测值,注意的是,要将值转成str类型的,再将两个列表传入到对应的函数即可。SEN 、PRE、 F1的值会直接给出一个列表。

from sklearn.metrics import classification_report
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score

def plot_confusion_matrix(y_true, y_predict):
    # ACC
    print('Accuracy:{:.3f}'.format(accuracy_score(y_true, y_predict))) 
    # SEN PRE F1
    pre= classification_report(y_true, y_predict, target_names=classes,digits=5) 
    print(pre)
  
    
    # 绘制热度图
    classes = ['Normal', 'Covid-19','Viral']
    confusion = confusion_matrix(y_true, y_predict)
    plt.imshow(confusion, interpolation='nearest', cmap=plt.cm.Blues)
    indices = range(len(confusion))
    plt.xticks(indices, classes,fontproperties = 'Times New Roman',fontsize=24)
    plt.yticks(indices, classes,fontproperties = 'Times New Roman',fontsize=24)
    plt.colorbar()
    plt.xlabel('Predictd label',fontsize=24,family='Times New Roman')
    plt.ylabel('True label',fontsize=24,family='Times New Roman')
    iters = np.reshape([[[i, j] for j in range(confusion.shape[0])] for i in range(confusion.shape[1])],(confusion.size, 2))
    for i, j in iters:
        plt.text(j, i, format(confusion[i, j]), va='center', ha='center', fontproperties='Times New Roman',fontsize=24)  #
    # 显示图片
    plt.show()



def test(cnn,test_dir):
    test_data = mydataset.MyDataset_image2(test_dir,"test")
    valid_data_size=len(test_data)
    test_loader = Data.DataLoader(dataset=test_data, batch_size=1, shuffle=True)
    loss_func = torch.nn.CrossEntropyLoss()
    valid_loss = 0.0
    valid_acc = 0.0
    y_true=[]
    y_predict=[]
    for step,(x, y) in enumerate(test_loader):
        if args.use_gpu:
            b_x = Variable(x).cuda()  # batch x
            b_y = Variable(y).cuda()  # batch y
        else:
            b_x = Variable(x)  # batch x
            b_y = Variable(y)  # batch y
        with torch.no_grad():  # this can save much memory
            label=b_y.data.cpu().item()
            y_true.append(str(label))
            test_output= cnn(b_x)
            loss = loss_func(test_output, b_y)
            valid_loss += loss.item() * b_x.size(0)
            ret, predictions = torch.max(test_output.data, 1)
            pre= predictions.data.cpu().item()
            y_predict.append(str(pre))
            correct_counts = predictions.eq(b_y.data.view_as(predictions))
            accuracy = torch.mean(correct_counts.type(torch.FloatTensor))
            valid_acc += accuracy.item() * x.size(0)

    plot_confusion_matrix(y_true, y_predict)
    avg_valid_loss = valid_loss / valid_data_size
    avg_valid_acc = valid_acc / valid_data_size
    print("Test loss:{} Test accuracy:{}".format(avg_valid_loss, avg_valid_acc))
    return avg_valid_loss, avg_valid_acc

 结果如下所示

 

  • 4
    点赞
  • 20
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值