python matplotlib绘制混淆矩阵并配色

步骤1:网络测试结果保存

以pytorch为例,在测试阶段保存结果的参考代码为:

resultTxtName = "result.txt"
resultfiledir = os.path.join(web_dir,resultTxtName)
f_result=open(resultfiledir, "a+")

if opt.which_model_netG[-5:] == "class":
	clss_loss_sum = 0
	right_numbers = 0
	for i, data in enumerate(dataset):
		# if i >= opt.how_many:
		# 	break
		counter = i + 1
		# pdb.set_trace()
		model.set_input(data)
		pred, flag, loss, right_number = model.test()
		right_numbers += right_number
		clss_loss_sum += loss
		new_result_context = str(pred.item()) + ';' + str(flag.item())+ '\n'
		f_result.write(new_result_context)
		# print(pred.item(),flag.item())
		if i % 100 == 0:
			print(clss_loss_sum/i,right_numbers / (i+1))
			f_result.close()
			f_result=open(resultfiledir, "a+")

步骤2:矩阵绘制

参考文章:
[1] https://blog.csdn.net/qq_33590958/article/details/103443215

对应代码:

from sklearn.metrics import confusion_matrix    # 生成混淆矩阵的函数
import numpy as np
from matplotlib import pyplot as plt
import pdb

plt.rcParams["font.sans-serif"]=["SimSun"]
'''
首先是从结果文件中读取预测标签与真实标签,然后将读取的标签信息传入python内置的混淆矩阵矩阵函数confusion_matrix(真实标签,
预测标签)中计算得到混淆矩阵,之后调用自己实现的混淆矩阵可视化函数plot_confusion_matrix()即可实现可视化。
三个参数分别是混淆矩阵归一化值,总的类别标签集合,可是化图的标题
'''
 
def plot_confusion_matrix(cm, labels_name, title):
    np.set_printoptions(precision=2)
    # print(cm)
    plt.imshow(cm, interpolation='nearest',cmap='YlOrBr')    # 在特定的窗口上显示图像
    # 显示text
    for first_index in range(len(cm)):    #第几行
        for second_index in range(len(cm[first_index])):    #第几列
            plt.text(first_index, second_index, '%.2f' % cm[first_index][second_index],
            horizontalalignment='center')
    plt.title(title)    # 图像标题
    plt.colorbar()
    num_local = np.array(range(len(labels_name)))    
    plt.xticks(num_local, labels_name, rotation=90)    # 将标签印在x轴坐标上
    plt.yticks(num_local, labels_name)    # 将标签印在y轴坐标上
    plt.ylabel('真实类别')    
    plt.xlabel('预测类别')
    # show confusion matrix
    plt.savefig('./fig/'+title+'.png', format='png')
 
gt = []
pre = []
with open("11c_result.txt", "r") as f:
    for line in f:
        line=line.rstrip()#rstrip() 删除 string 字符串末尾的指定字符(默认为空格)
        words=line.split(';')
        pre.append(int(words[0]))
        gt.append(int(eval(words[1])))
 
cm=confusion_matrix(gt,pre)  #计算混淆矩阵
print('type=',type(cm))
cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]    # 归一化
labels = [0,1,2,3,4,5,6,7,8,9,10]  #类别集合
# pdb.set_trace()
plot_confusion_matrix(cm,labels,'隐藏场景分类探测任务的混淆矩阵')  #绘制混淆矩阵图,可视化

相比于文章[1],主要修改了如下几个地方:

  1. 学位论文要求图例中文宋体

plt.rcParams["font.sans-serif"]=["SimSun"]

  1. 为了美观,使用暖色调的colormap

plt.imshow(cm, interpolation='nearest',cmap='YlOrBr')

  1. 在混淆矩阵中写入了概率
    for first_index in range(len(cm)):    #第几行
        for second_index in range(len(cm[first_index])):    #第几列
            plt.text(first_index, second_index, '%.2f' % cm[first_index][second_index],
            horizontalalignment='center')

混淆矩阵绘制结果

在这里插入图片描述

  • 3
    点赞
  • 28
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 7
    评论
评论 7
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

R.X. NLOS

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值