绘制混淆矩阵热力图

Python绘制混淆矩阵热力图


用matplotlib绘制混淆矩阵,可以通过改变 imshow 函数中的 cmap 参数来修改颜色。cmap 参数接受一个 colormap 的名字,你可以选择许多不同的 colormap,例如 ‘viridis’, ‘plasma’, ‘inferno’, ‘magma’, ‘cividis’, ‘cool’, ‘hot’ 等等。具体的 colormap 可以参考 matplotlib 的文档。

# 案例1:

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap, Normalize

conf_arr = [[33, 2, 0, 0, 0, 0, 0, 0, 0, 1, 3],
            [3, 31, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            [0, 4, 41, 0, 0, 0, 0, 0, 0, 0, 1],
            [0, 1, 0, 30, 0, 6, 0, 0, 0, 0, 1],
            [0, 0, 0, 0, 38, 10, 0, 0, 0, 0, 0],
            [0, 0, 0, 3, 1, 39, 0, 0, 0, 0, 4],
            [0, 2, 2, 0, 4, 1, 31, 0, 0, 0, 2],
            [0, 1, 0, 0, 0, 0, 0, 36, 0, 2, 0],
            [0, 0, 0, 0, 0, 0, 1, 5, 37, 5, 1],
            [3, 0, 0, 0, 0, 0, 0, 0, 0, 39, 0],
            [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 38]]

# 自定义渐变颜色
colors = ['#ffffff', '#ffcccc', '#ff6666', '#ff0000', '#990000']
custom_cmap = LinearSegmentedColormap.from_list('custom_cmap', colors)

# 归一化对象,定义最小值和最大值
norm = Normalize(vmin=0, vmax=np.max(conf_arr))

norm_conf = []
for i in conf_arr:
    a = 0
    tmp_arr = []
    a = sum(i, 0)
    for j in i:
        tmp_arr.append(float(j)/float(a))
    norm_conf.append(tmp_arr)

fig = plt.figure()
plt.clf()
ax = fig.add_subplot(111)
ax.set_aspect(1)
res = ax.imshow(np.array(conf_arr), cmap=custom_cmap, norm=norm, interpolation='nearest')

# 获取矩阵的宽度和高度
width, height = np.array(conf_arr).shape

# 使用 range 替换 xrange
for x in range(width):
    for y in range(height):
        ax.annotate(str(conf_arr[x][y]), xy=(y, x),
                    horizontalalignment='center',
                    verticalalignment='center',
                    color='black')

cb = fig.colorbar(res)

# 调整 alphabet 使其不会超出索引范围
alphabet = 'ABCDEFGHIJKLMNOPQRSTUVWXYZ'[:width]
plt.xticks(range(width), alphabet)
plt.yticks(range(height), alphabet)

plt.savefig('confusion_matrix.png', format='png')
plt.show()

效果图如下:
在这里插入图片描述

# 案例2:
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

# 创建混淆矩阵数据
intra_inter = np.array([
    [84.9, 4.7, 7.3, 0.5, 0.5, 1.6,0.1,0.2],
    [4.0, 89.6, 0.6, 4.0, 1.7, 0.1,0.3,0.1],
    [6.7, 0.5, 85.7, 1.4, 4.3, 1.0,0.2,0.4],
    [2.3, 0.0, 5.7, 87.4, 4.6, 0.0,0.3,0.1],
    [1.0, 1.0, 4.9, 1.0, 92.2, 0.0,0.0,0.3],
    [0.4, 0.4, 2.2, 0.4, 0.9, 95.1,0.2,0.4],
    [0.6, 0.6, 0.6, 2.2, 5.0, 0.1,96.1,0.3],
    [2.4, 0.6, 0.6, 7.7, 3.6, 4.5,0.2,95.1]
])

intra_inter_obj = np.array([
    [89.1, 1.6, 7.8, 0.5, 0.5, 1.0,0.1,0.2],
    [1.7, 93.1, 1.2, 3.5, 0.6, 0.0,0.1,0.2],
    [6.7, 1.0, 91.4, 0.5, 0.5, 1.0,0.1,0.2],
    [1.1, 1.1, 4.6, 94.3, 4.6, 0.0,0.1,0.2],
    [1.0, 1.0, 6.9, 1.0, 90.2, 0.0,0.1,0.2],
    [0.4, 0.4, 0.4, 0.4, 0.4, 95.1,0.1,0.2],
    [7.7, 3.6, 3.6, 7.7, 2.7, 0.1,93.1,0.1],
    [2.4, 0.6, 0.6, 7.7, 3.6, 4.5,0.2,95.1]
])

labels = ["Right set", "Right spike", "Right pass", "Right winpoint", "Left winpoint", "Left pass", "Left spike", "Left set"]

# 每个矩阵的维度
n_labels = len(labels)
cell_size = 1.5  # 每个小块的大小

# 创建图形
fig, ax = plt.subplots(1, 2, figsize=(n_labels * cell_size * 2, n_labels * cell_size))

# 绘制第一个混淆矩阵
sns.heatmap(intra_inter, annot=True, fmt=".1f", cmap="Blues", ax=ax[0], xticklabels=labels, yticklabels=labels, cbar=False)
ax[0].set_title("Intra+Inter")
ax[0].set_xlabel('')
ax[0].set_ylabel('')

# 绘制第二个混淆矩阵
sns.heatmap(intra_inter_obj, annot=True, fmt=".1f", cmap="Blues", ax=ax[1], xticklabels=labels, yticklabels=labels, cbar=True)
ax[1].set_title("Intra+Inter+Object")
ax[1].set_xlabel('')
ax[1].set_ylabel('')

# 调整底部标签的倾斜角度
plt.setp(ax[0].get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")
plt.setp(ax[1].get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")

# 获取颜色条对象
cbar = ax[1].collections[0].colorbar

# 调整颜色条的位置和大小
cbar.ax.set_position([0.92, ax[1].get_position().y0, 0.02, ax[1].get_position().height])

# 调整布局
plt.tight_layout(rect=[0, 0, 0.9, 1])
plt.show()

效果图如下:
在这里插入图片描述

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

ccwRadar

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值