python画混淆矩阵

#coding=utf-8
import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import confusion_matrix

save_flg = True

# confusion = confusion_matrix(y_test, y_pred)
confusion = np.array([[97, 2,  0,  0, 1, 0],
                     [ 4, 94,  1,  21, 0, 0],
                     [ 3,  2, 95,  0, 0, 0],
                     [ 0,  0,  0, 98, 2, 0],
                     [ 3,  1,  0,  0,96, 0],
                     [ 0,  1,  3,  0, 6,90]])

plt.figure(figsize=(5, 5))  #设置图片大小


# 1.热度图,后面是指定的颜色块,cmap可设置其他的不同颜色
plt.imshow(confusion, cmap=plt.cm.Blues)
plt.colorbar()   # 右边的colorbar


# 2.设置坐标轴显示列表
indices = range(len(confusion))    
classes = ['A', 'B', 'C', 'D', 'E', 'F']  
# 第一个是迭代对象,表示坐标的显示顺序,第二个参数是坐标轴显示列表
plt.xticks(indices, classes, rotation=45) # 设置横坐标方向,rotation=45为45度倾斜
plt.yticks(indices, classes)


# 3.设置全局字体
# 在本例中,坐标轴刻度和图例均用新罗马字体['TimesNewRoman']来表示
# ['SimSun']宋体;['SimHei']黑体,有很多自己都可以设置
plt.rcParams['font.sans-serif'] = ['SimHei']  
plt.rcParams['axes.unicode_minus'] = False


# 4.设置坐标轴标题、字体
# plt.ylabel('True label')
# plt.xlabel('Predicted label')
# plt.title('Confusion matrix')

plt.xlabel('预测值')
plt.ylabel('真实值')
plt.title('混淆矩阵', fontsize=12, fontfamily="SimHei")  #可设置标题大小、字体


# 5.显示数据
normalize = False
fmt = '.2f' if normalize else 'd'
thresh = confusion.max() / 2.

for i in range(len(confusion)):    #第几行
    for j in range(len(confusion[i])):    #第几列
        plt.text(j, i, format(confusion[i][j], fmt),
        fontsize=16,  # 矩阵字体大小
        horizontalalignment="center",  # 水平居中。
        verticalalignment="center",  # 垂直居中。
        color="white" if confusion[i, j] > thresh else "black")


#6.保存图片
if save_flg:  
    plt.savefig("./picture/confusion_matrix.png")


# 7.显示
plt.show()

效果:

 

  • 18
    点赞
  • 123
    收藏
    觉得还不错? 一键收藏
  • 7
    评论
混淆矩阵是用于评估分类模型性能的一种工具,它可以显示模型在不同类别上的预测结果和真实标签之间的对应关系。在PyTorch中,我们可以使用混淆矩阵来评估模型的分类准确性。 首先,我们需要导入必要的库和函数进行混淆矩阵的计算和可视化。可以参考和中的代码实现部分。 1. 数据集:在计算混淆矩阵之前,我们需要准备好一个验证集,该验证集包含模型预测的结果和真实标签。可以参考中的代码实现部分。 2. 代码:混淆矩阵类:在PyTorch中,可以通过编写一个混淆矩阵类来计算混淆矩阵。可以参考中的代码实现部分。 3. 在验证集上计算相关指标:使用混淆矩阵类计算验证集上的混淆矩阵,并计算相关指标,例如准确率、召回率、F1分数等。可以参考中的代码实现部分。 4. 结果:通过计算混淆矩阵和相关指标,我们可以得到模型在验证集上的分类性能结果。可以将混淆矩阵可视化,以更直观地理解模型的分类表现。可以参考中的代码实现部分。 绘制混淆矩阵的过程包括以下步骤: 1. 将混淆矩阵赋值给一个变量。 2. 打印混淆矩阵。 3. 使用imshow函数展示混淆矩阵,设置颜色变换从白色到蓝色。 4. 使用xticks函数将x轴的信息(0~num_classes-1)替换为标签的类别,并将x轴旋转45°。 5. 同理,使用yticks函数将y轴的信息替换为标签的类别。 6. 添加一个右侧颜色条,用来表示混淆矩阵中数值的密集程度,颜色越深表示数值越密集。 7. 设置横坐标为真实标签,纵坐标为预测标签。 8. 添加图像标题,例如"Confusion matrix"。 通过以上步骤,我们可以绘制出一个具有标签类别的混淆矩阵图像,该图像可以帮助我们更好地理解模型在不同类别上的分类表现。<span class="em">1</span><span class="em">2</span><span class="em">3</span> #### 引用[.reference_title] - *1* *2* [混淆矩阵:用于多分类模型评估(pytorch)](https://blog.csdn.net/weixin_43760844/article/details/115208925)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v93^chatsearchT3_2"}}] [.reference_item style="max-width: 50%"] - *3* [人工智能学习07--pytorch11--分类网络:使用pytorch和tensorflow计算分类模型的混淆矩阵](https://blog.csdn.net/AMWICD/article/details/129443938)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v93^chatsearchT3_2"}}] [.reference_item style="max-width: 50%"] [ .reference_list ]

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 7
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值