分类任务绘制混淆矩阵


前言

在二分类或者多分类的过程中,无论是在训练还是在验证的过程中的每一个epoch都可以通过绘制混淆矩阵来查看准确率,精确率,灵敏度/召回率,特异度。因此绘制混淆矩阵非常重要,此篇文章可以了解并通过代码进行使用confusion matrix


一、混淆矩阵包括什么

在这里插入图片描述

二、计算公式

在这里插入图片描述

三、源码

import os
import json
import torch
import numpy as np
import matplotlib.pyplot as plt
from prettytable import PrettyTable



class ConfusionMatrix(object):
    """
    num_classes:分类网络的类别个数 21
    labels:对应的分类类别列表
    """
    def __init__(self, num_classes: int, labels: list, save_dir: str):
        self.matrix = np.zeros((num_classes, num_classes))
        self.num_classes = num_classes
        self.labels = labels
        self.save_dir = save_dir

    '''
    preds:预测的值
    labels:真实的标签
    '''
    def update(self, preds, labels):
        for p, t in zip(preds, labels):
            self.matrix[p, t] += 1

    def summary(self):
        # calculate accuracy
        sum_TP = 0
        for i in range(self.num_classes):
            sum_TP += self.matrix[i, i]
        acc = sum_TP / np.sum(self.matrix)
        print("the model accuracy is ", acc)

        # precision, recall, specificity
        table = PrettyTable()
        table.field_names = ["", "Precision", "Recall", "Specificity"]
        for i in range(self.num_classes):
            TP = self.matrix[i, i]
            FP = np.sum(self.matrix[i, :]) - TP
            FN = np.sum(self.matrix[:, i]) - TP
            TN = np.sum(self.matrix) - TP - FP - FN
            Precision = round(TP / (TP + FP), 3) if TP + FP != 0 else 0.
            Recall = round(TP / (TP + FN), 3) if TP + FN != 0 else 0.
            Specificity = round(TN / (TN + FP), 3) if TN + FP != 0 else 0.
            table.add_row([self.labels[i], Precision, Recall, Specificity])
        print(table)

    def plot(self,epoch,index):
        matrix = self.matrix
        print(matrix)
        plt.figure(figsize=(20, 20))
        # color is come cm to blue.
        plt.imshow(matrix, cmap=plt.cm.Blues)

        # 设置x轴坐标label
        plt.xticks(range(self.num_classes), self.labels, rotation=45)
        # 设置y轴坐标label
        plt.yticks(range(self.num_classes), self.labels)
        # 显示colorbar
        plt.colorbar()
        plt.xlabel('True Labels')
        plt.ylabel('Predicted Labels')
        plt.title('Confusion matrix')

        # 在图中标注数量/概率信息
        thresh = matrix.max() / 2
        for x in range(self.num_classes):
            for y in range(self.num_classes):
                # 注意这里的matrix[y, x]不是matrix[x, y]
                info = int(matrix[y, x])
                plt.text(x, y, info,
                         verticalalignment='center',
                         horizontalalignment='center',
                         color="white" if info > thresh else "black")
        # 图形显示紧凑
        # plt.tight_layout()
        # plt.figure(figsize=(12,12))
        matrix_save_path = os.path.join(self.save_dir, f'{epoch}_{index}_confusion_matrix.png')
        plt.savefig(matrix_save_path, bbox_inches='tight')
        # plt.show()
        # return plt,matrix

if __name__ == '__main__':
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(device)
    json_label_path = 'E:/recovery_source_code/Movement_Classification/utils/class_indices.json'
    assert os.path.exists(json_label_path), "cannot find {} file".format(json_label_path)
    with open(json_label_path, 'r', encoding="utf-8") as file:
        class_indict = json.load(file)
    labels = [value for _, value in class_indict.items()]
    confusion = ConfusionMatrix(num_classes=21, labels=labels)

四、训练和测试过程中使用

训练代码:

#代码不全 主要查看confusion_matrix类如何使用
mean_total_loss = []
json_label_path = 'E:/recovery_source_code/Movement_Classification/utils/class_indices.json'
assert os.path.exists(json_label_path), "cannot find {} file".format(json_label_path)
with open(json_label_path, 'r', encoding="utf-8") as file:
    class_indict = json.load(file)
confusion_matrix_figure_path = "E:/recovery_source_code/Movement_Classification/Result/Confusion_matrix"
for epoch in range(epochs):
	 l = [key for key, _ in class_indict.items()]
     confusion = ConfusionMatrix(num_classes=21, labels=l, save_dir=confusion_matrix_figure_path)
     print(f"***********************************{epoch+1}/{epochs}************************************")
    model.train(True)
    for batch_idx, (data, labels, file) in enumerate(dataloader):
    	data = data.to(device=device)
        labels = labels.to(device=device)
        # data = data.permute(0, 3, 1, 2).unsqueeze(4)
        data = data.permute(0, 3, 1, 2)
        labels = labels.to(dtype = torch.long)
        # 前向传播
        outputs = model(data)
        max_indices = torch.argmax(outputs, dim=2)
        max_indices_cpu = np.squeeze(max_indices.cpu().numpy())
        labels_cpu = np.squeeze(labels.cpu().numpy().astype(int))
        confusion.update(max_indices_cpu, labels_cpu)
        if batch_idx %100==0:
            confusion.plot(epoch,batch_idx)
            confusion.summary()
        outputs = outputs.permute(0, 2, 1)
        loss = criterion(outputs, labels)
        total_loss = total_loss + loss
        # 反向传播和优化
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    

测试代码:

#代码不全主要查看confusion_matrix类如何在测试代码中使用
data = data.to(device=device)
        labels = labels.to(device=device)
        # data = data.permute(0, 3, 1, 2).unsqueeze(4)
        data = data.permute(0, 3, 1, 2)
        labels = labels.to(dtype = torch.long)
        # 前向传播
        outputs = model(data)
        max_indices = torch.argmax(outputs, dim=2)
        max_indices_cpu = np.squeeze(max_indices.cpu().numpy())
        labels_cpu = np.squeeze(labels.cpu().numpy().astype(int))
        confusion.update(max_indices_cpu, labels_cpu)
        if batch_idx %100==0:
            confusion.plot(epoch,batch_idx)
            confusion.summary()
        outputs = outputs.permute(0, 2, 1)
        loss = criterion(outputs, labels)
        total_loss = total_loss + loss
        # 反向传播和优化
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

总结

现在对于26个骨骼点时序动作分类模型效果不行,还需要进一步编写代码,预处理,训练,测试,demo类编写等任务.

  • 8
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值