pytorch分类模型绘制混淆矩阵及可视化

该博客介绍了如何在PyTorch中计算并可视化分类模型的混淆矩阵。首先,通过定义一个空的混淆矩阵,并在测试数据集上迭代,使用torch.no_grad()来减少GPU使用。接着,利用confusion_matrix函数更新混淆矩阵。最后,将混淆矩阵转换为numpy数组,计算每个类别的识别正确个数和准确率,并进行可视化。混淆矩阵的可视化通过matplotlib完成,显示了每个类别的识别情况。
摘要由CSDN通过智能技术生成

pytorch分类模型绘制混淆矩阵及可视化

Step 1. 获取混淆矩阵

#首先定义一个 分类数*分类数 的空混淆矩阵
 conf_matrix = torch.zeros(Emotion_kinds, Emotion_kinds)
 # 使用torch.no_grad()可以显著降低测试用例的GPU占用
    with torch.no_grad():
        for step, (imgs, targets) in enumerate(test_loader):
            # imgs:     torch.Size([50, 3, 200, 200])   torch.FloatTensor
            # targets:  torch.Size([50, 1]),     torch.LongTensor  多了一维,所以我们要把其去掉
            targets = targets.squeeze()  # [50,1] ----->  [50]

            # 将变量转为gpu
            targets = targets.cuda()
            imgs = imgs.cuda()
            # print(step,imgs.shape,imgs.type(),targets.shape,targets.type())
            
            out = model(imgs)
            #记录混淆矩阵参数
            conf_matrix = confusion_matrix(out, targets, conf_matrix)
            conf_matrix=conf_matrix.cpu()

混淆矩阵的求取用到了confusion_matrix函数,其定义如下:

def confusion_matrix(preds, labels, conf_matrix):
    preds = torch.argmax(preds, 1)
    for p, t in zip(preds, labels):
        conf_matrix[p, t] += 1
    return conf_matrix

在当我们的程序执行结束 test_loader 后,我们可以得到本次数据的 混淆矩阵,接下来就要计算其 识别正确的个数以及混淆矩阵可视化:

conf_matrix=np.array(conf_matrix.cpu())# 将混淆矩阵从gpu转到cpu再转到np
corrects=conf_matrix.diagonal(offset=0)#抽取对角线的每种分类的识别正确个数
per_kinds=conf_matrix.sum(axis=1)#抽取每个分类数据总的测试条数

 print("混淆矩阵总元素个数:{0},测试集总个数:{1}".format(int(np.sum(conf_matrix)),test_num))
 print(conf_matrix)

 # 获取每种Emotion的识别准确率
 print("每种情感总个数:",per_kinds)
 print("每种情感预测正确的个数:",corrects)
 print("每种情感的识别准确率为:{0}".format([rate*100 for rate in corrects/per_kinds]))

执行此步的输出结果如下所示:
在这里插入图片描述

Step 2. 混淆矩阵可视化

对上边求得的混淆矩阵可视化

# 绘制混淆矩阵
Emotion=8#这个数值是具体的分类数,大家可以自行修改
labels = ['neutral', 'calm', 'happy', 'sad', 'angry', 'fearful', 'disgust', 'surprised']#每种类别的标签

# 显示数据
plt.imshow(conf_matrix, cmap=plt.cm.Blues)

# 在图中标注数量/概率信息
thresh = conf_matrix.max() / 2	#数值颜色阈值,如果数值超过这个,就颜色加深。
for x in range(Emotion_kinds):
    for y in range(Emotion_kinds):
        # 注意这里的matrix[y, x]不是matrix[x, y]
        info = int(conf_matrix[y, x])
        plt.text(x, y, info,
                 verticalalignment='center',
                 horizontalalignment='center',
                 color="white" if info > thresh else "black")
                 
plt.tight_layout()#保证图不重叠
plt.yticks(range(Emotion_kinds), labels)
plt.xticks(range(Emotion_kinds), labels,rotation=45)#X轴字体倾斜45°
plt.show()
plt.close()

好了,以下就是最终的可视化的混淆矩阵啦:
在这里插入图片描述

其它分类指标的获取

例如 F1分数、TP、TN、FP、FN、精确率、召回率 等指标, 待补充哈(因为暂时还没用到)~

  • 22
    点赞
  • 196
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 15
    评论
PyTorch中,可以使用matplotlib库来绘制混淆矩阵。下面是一个示例代码,展示了如何使用PyTorch和matplotlib来混淆矩阵: ```python import numpy as np import matplotlib.pyplot as plt # 定义混淆矩阵 def confusion_matrix(preds, labels, conf_matrix): for p, t in zip(preds, labels): conf_matrix\[p, t\] += 1 return conf_matrix # 绘制混淆矩阵 def plot_matrix(matrix, labels): plt.imshow(matrix, cmap=plt.cm.Blues) for x in range(len(labels)): for y in range(len(labels)): info = matrix\[y, x\] plt.text(x, y, info, verticalalignment='center', horizontalalignment='center') plt.xticks(range(len(labels)), labels, rotation=45) plt.yticks(range(len(labels)), labels) plt.xlabel('Predicted Label') plt.ylabel('True Label') plt.title('Confusion Matrix') plt.colorbar() plt.show() # 示例代码 preds = \[0, 1, 2, 1, 0, 2, 2\] labels = \[0, 1, 2, 1, 0, 1, 2\] conf_matrix = np.zeros((3, 3)) conf_matrix = confusion_matrix(preds, labels, conf_matrix) plot_matrix(conf_matrix, \['Label 0', 'Label 1', 'Label 2'\]) ``` 这段代码首先定义了一个混淆矩阵函数`confusion_matrix`,用于计算混淆矩阵。然后定义了一个绘制混淆矩阵的函数`plot_matrix`,该函数使用matplotlib库来绘制混淆矩阵。最后,通过调用这两个函数,可以计算准确率并绘制混淆矩阵。 希望这个示例代码能够帮助你理解如何使用PyTorch和matplotlib来混淆矩阵。如果有任何问题,请随时提问。 #### 引用[.reference_title] - *1* *2* [混淆矩阵绘制](https://blog.csdn.net/qq_45470799/article/details/123737859)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v91^control_2,239^v3^insert_chatgpt"}} ] [.reference_item] - *3* [pytorch绘制混淆矩阵](https://blog.csdn.net/qq_18617009/article/details/103345308)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v91^control_2,239^v3^insert_chatgpt"}} ] [.reference_item] [ .reference_list ]
评论 15
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

王延凯的博客

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值