混淆矩阵:用于多分类模型评估(pytorch)——总结2

43 篇文章 6 订阅

1. 混淆矩阵介绍

这里不多说,可参考

2. 代码实现(自己设计实现,不用sklearn库)

2.1 数据集

此数据集用于多分类任务(检测番茄叶片病虫害)。这里测试的数据集一共1250张图,1000张用于训练,250张用于验证,共分为5个类别。数据集结构如下:
在这里插入图片描述
数据集部分图片展示:
在这里插入图片描述

2.2 代码:混淆矩阵类

计算accuracy、kappa、precision、recall、specificity

class ConfusionMatrix(object):


    def __init__(self, num_classes: int, labels: list):
        self.matrix = np.zeros((num_classes, num_classes))#初始化混淆矩阵,元素都为0
        self.num_classes = num_classes#类别数量,本例数据集类别为5
        self.labels = labels#类别标签

    def update(self, preds, labels):
        for p, t in zip(preds, labels):#pred为预测结果,labels为真实标签
            self.matrix[p, t] += 1#根据预测结果和真实标签的值统计数量,在混淆矩阵相应位置+1

    def summary(self):#计算指标函数
        # calculate accuracy
        sum_TP = 0
        # 计算测试样本的总数
        n = np.sum(self.matrix)
        for i in range(self.num_classes):
            sum_TP += self.matrix[i, i]#混淆矩阵对角线的元素之和,也就是分类正确的数量
        acc = sum_TP / n#总体准确率
        print("the model accuracy is ", acc)
		
		# kappa
        sum_po = 0
        sum_pe = 0
        for i in range(len(self.matrix[0])):
            sum_po += self.matrix[i][i]
            row = np.sum(self.matrix[i, :])
            col = np.sum(self.matrix[:, i])
            sum_pe += row * col
        po = sum_po / n
        pe = sum_pe / (n * n)
        # print(po, pe)
        kappa = round((po - pe) / (1 - pe), 3)
        #print("the model kappa is ", kappa)
        
        # precision, recall, specificity
        table = PrettyTable()#创建一个表格
        table.field_names = ["", "Precision", "Recall", "Specificity"]
        for i in range(self.num_classes):#精确度、召回率、特异度的计算
            TP = self.matrix[i, i]
            FP = np.sum(self.matrix[i, :]) - TP
            FN = np.sum(self.matrix[:, i]) - TP
            TN = np.sum(self.matrix) - TP - FP - FN

            Precision = round(TP / (TP + FP), 3) if TP + FP != 0 else 0.
            Recall = round(TP / (TP + FN), 3) if TP + FN != 0 else 0.#每一类准确度
            Specificity = round(TN / (TN + FP), 3) if TN + FP != 0 else 0.

            table.add_row([self.labels[i], Precision, Recall, Specificity])
        print(table)
        return str(acc)

    def plot(self):#绘制混淆矩阵
        matrix = self.matrix
        print(matrix)
        plt.imshow(matrix, cmap=plt.cm.Blues)

        # 设置x轴坐标label
        plt.xticks(range(self.num_classes), self.labels, rotation=45)
        # 设置y轴坐标label
        plt.yticks(range(self.num_classes), self.labels)
        # 显示colorbar
        plt.colorbar()
        plt.xlabel('True Labels')
        plt.ylabel('Predicted Labels')
        plt.title('Confusion matrix (acc='+self.summary()+')')

        # 在图中标注数量/概率信息
        thresh = matrix.max() / 2
        for x in range(self.num_classes):
            for y in range(self.num_classes):
                # 注意这里的matrix[y, x]不是matrix[x, y]
                info = int(matrix[y, x])
                plt.text(x, y, info,
                         verticalalignment='center',
                         horizontalalignment='center',
                         color="white" if info > thresh else "black")
        plt.tight_layout()
        plt.show()


2.3 在验证集上计算相关指标

在每个epoch计算一次指标,输出混淆矩阵并可视化

训练过程验证部分代码如下:

		class_indict = config.tomato_DICT
		#tomato_DICT = {'0': 'Bacterial_spot', '1': 'Early_blight', '2': 'healthy', '3': 'Late_blight', '4': 'Leaf_Mold'}
		# 标签名字列表
        label = [label for _, label in class_indict.items()]
        confusion = ConfusionMatrix(num_classes=config.NUM_CLASSES, labels=label)
        #实例化混淆矩阵,这里NUM_CLASSES = 5

        with torch.no_grad():
            model.eval()#验证
            for j, (inputs, labels) in enumerate(val_data):
                inputs = inputs.to(device)
                labels = labels.to(device)
                output = model(inputs)#分类网络的输出,分类器用的softmax,即使不使用softmax也不影响分类结果。
                loss = loss_function(output, labels)
                valid_loss += loss.item() * inputs.size(0)
                ret, predictions = torch.max(output.data, 1)#torch.max获取output最大值以及下标,predictions即为预测值(概率最大),这里是获取验证集每个batchsize的预测结果
                #confusion_matrix
                confusion.update(predictions.cpu().numpy(), labels.cpu().numpy())


            confusion.plot()
            confusion.summary()

2.4 结果

训练30个epoch,在第29个epoch取得最好的结果:
在这里插入图片描述

在这里插入图片描述

真实标签和预测标签在不同位置(x坐标和y坐标)都是可以的,看个人习惯,计算的时候注意就行了

转载自:

https://blog.csdn.net/weixin_43760844/article/details/115208925

  • 9
    点赞
  • 49
    收藏
    觉得还不错? 一键收藏
  • 2
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值