动手画混淆矩阵(Confusion Matrix)(含代码)

更多文章推荐

网上关于混淆矩阵的代码参差不齐,没找到可用的线程的代码,所以自己尝试写了下

1、混淆矩阵:Confusion Matrix

首先它长这样:
在这里插入图片描述

怎么看?
Confusion Matrix最广泛的应用应该是分类,比如图中是7分类的真实标签和预测标签的效果。
首先图中表明了纵轴是truth label,横轴是predicted label,那么对于第一行第一个0.60的含义是:本来是angry标签的图,我的模型正确分类成angry的比例是60%,也即是angry这一类模型分类正确的精度只有60%。同时模型将angry分类成了happy的图占比0.04%,其他的以此类推。

注意:因为本身是angry,模型预测成7种类的数量占比。所以每一行的和为100%。

同时对于fear标签,模型分类成fear的占比41%,分类成sad的占比为20%,我们可以认为模型不能很好区分fear和sad两种类别。

2、怎么画(新)?

这里直接给出代码,在下一节中直接使用即可

import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix


def draw_confusion_matrix(label_true, label_pred, label_name, title="Confusion Matrix", pdf_save_path=None, dpi=100):
    """

    @param label_true: 真实标签,比如[0,1,2,7,4,5,...]
    @param label_pred: 预测标签,比如[0,5,4,2,1,4,...]
    @param label_name: 标签名字,比如['cat','dog','flower',...]
    @param title: 图标题
    @param pdf_save_path: 是否保存,是则为保存路径pdf_save_path=xxx.png | xxx.pdf | ...等其他plt.savefig支持的保存格式
    @param dpi: 保存到文件的分辨率,论文一般要求至少300dpi
    @return:

    example:
            draw_confusion_matrix(label_true=y_gt,
                          label_pred=y_pred,
                          label_name=["Angry", "Disgust", "Fear", "Happy", "Sad", "Surprise", "Neutral"],
                          title="Confusion Matrix on Fer2013",
                          pdf_save_path="Confusion_Matrix_on_Fer2013.png",
                          dpi=300)

    """
    cm = confusion_matrix(y_true=label_true, y_pred=label_pred, normalize='true')

    plt.imshow(cm, cmap='Blues')
    plt.title(title)
    plt.xlabel("Predict label")
    plt.ylabel("Truth label")
    plt.yticks(range(label_name.__len__()), label_name)
    plt.xticks(range(label_name.__len__()), label_name, rotation=45)

    plt.tight_layout()

    plt.colorbar()

    for i in range(label_name.__len__()):
        for j in range(label_name.__len__()):
            color = (1, 1, 1) if i == j else (0, 0, 0)  # 对角线字体白色,其他黑色
            value = float(format('%.2f' % cm[j, i]))
            plt.text(i, j, value, verticalalignment='center', horizontalalignment='center', color=color)

    # plt.show()
    if not pdf_save_path is None:
        plt.savefig(pdf_save_path, bbox_inches='tight', dpi=dpi)


3、怎么用?

给出一个简单的实例:

y_gt=[]
y_pred=[]
for index, (labels, imgs) in enumerate(test_loader):
    labels_pd = model(imgs)
    predict_np = np.argmax(labels_pd.cpu().detach().numpy(), axis=-1)   # array([0,5,1,6,3,...],dtype=int64)
    labels_np = labels.numpy()                                          # array([0,5,0,6,2,...],dtype=int64)
	
	y_pred.append(predict_np)
	y_gt.append(labels_np)
    
draw_confusion_matrix(label_true=y_gt,			# y_gt=[0,5,1,6,3,...]
                      label_pred=y_pred,	    # y_pred=[0,5,1,6,3,...]
                      label_name=["An", "Di", "Fe", "Ha", "Sa", "Su", "Ne"],
                      title="Confusion Matrix on Fer2013",
                      pdf_save_path="Confusion_Matrix_on_Fer2013.jpg",
                      dpi=300)
  • cpu().detach():从device上获取数据
  • .numpy():将tensor类型转换为numpy类型

在我的模型上的结果:
在这里插入图片描述

评论 73
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

我是一个对称矩阵

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值