Python绘制混淆矩阵热力图
用matplotlib绘制混淆矩阵,可以通过改变 imshow 函数中的 cmap 参数来修改颜色。cmap 参数接受一个 colormap 的名字,你可以选择许多不同的 colormap,例如 ‘viridis’, ‘plasma’, ‘inferno’, ‘magma’, ‘cividis’, ‘cool’, ‘hot’ 等等。具体的 colormap 可以参考 matplotlib 的文档。
# 案例1:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap, Normalize
conf_arr = [[33, 2, 0, 0, 0, 0, 0, 0, 0, 1, 3],
[3, 31, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 4, 41, 0, 0, 0, 0, 0, 0, 0, 1],
[0, 1, 0, 30, 0, 6, 0, 0, 0, 0, 1],
[0, 0, 0, 0, 38, 10, 0, 0, 0, 0, 0],
[0, 0, 0, 3, 1, 39, 0, 0, 0, 0, 4],
[0, 2, 2, 0, 4, 1, 31, 0, 0, 0, 2],
[0, 1, 0, 0, 0, 0, 0, 36, 0, 2, 0],
[0, 0, 0, 0, 0, 0, 1, 5, 37, 5, 1],
[3, 0, 0, 0, 0, 0, 0, 0, 0, 39, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 38]]
# 自定义渐变颜色
colors = ['#ffffff', '#ffcccc', '#ff6666', '#ff0000', '#990000']
custom_cmap = LinearSegmentedColormap.from_list('custom_cmap', colors)
# 归一化对象,定义最小值和最大值
norm = Normalize(vmin=0, vmax=np.max(conf_arr))
norm_conf = []
for i in conf_arr:
a = 0
tmp_arr = []
a = sum(i, 0)
for j in i:
tmp_arr.append(float(j)/float(a))
norm_conf.append(tmp_arr)
fig = plt.figure()
plt.clf()
ax = fig.add_subplot(111)
ax.set_aspect(1)
res = ax.imshow(np.array(conf_arr), cmap=custom_cmap, norm=norm, interpolation='nearest')
# 获取矩阵的宽度和高度
width, height = np.array(conf_arr).shape
# 使用 range 替换 xrange
for x in range(width):
for y in range(height):
ax.annotate(str(conf_arr[x][y]), xy=(y, x),
horizontalalignment='center',
verticalalignment='center',
color='black')
cb = fig.colorbar(res)
# 调整 alphabet 使其不会超出索引范围
alphabet = 'ABCDEFGHIJKLMNOPQRSTUVWXYZ'[:width]
plt.xticks(range(width), alphabet)
plt.yticks(range(height), alphabet)
plt.savefig('confusion_matrix.png', format='png')
plt.show()
效果图如下:
# 案例2:
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
# 创建混淆矩阵数据
intra_inter = np.array([
[84.9, 4.7, 7.3, 0.5, 0.5, 1.6,0.1,0.2],
[4.0, 89.6, 0.6, 4.0, 1.7, 0.1,0.3,0.1],
[6.7, 0.5, 85.7, 1.4, 4.3, 1.0,0.2,0.4],
[2.3, 0.0, 5.7, 87.4, 4.6, 0.0,0.3,0.1],
[1.0, 1.0, 4.9, 1.0, 92.2, 0.0,0.0,0.3],
[0.4, 0.4, 2.2, 0.4, 0.9, 95.1,0.2,0.4],
[0.6, 0.6, 0.6, 2.2, 5.0, 0.1,96.1,0.3],
[2.4, 0.6, 0.6, 7.7, 3.6, 4.5,0.2,95.1]
])
intra_inter_obj = np.array([
[89.1, 1.6, 7.8, 0.5, 0.5, 1.0,0.1,0.2],
[1.7, 93.1, 1.2, 3.5, 0.6, 0.0,0.1,0.2],
[6.7, 1.0, 91.4, 0.5, 0.5, 1.0,0.1,0.2],
[1.1, 1.1, 4.6, 94.3, 4.6, 0.0,0.1,0.2],
[1.0, 1.0, 6.9, 1.0, 90.2, 0.0,0.1,0.2],
[0.4, 0.4, 0.4, 0.4, 0.4, 95.1,0.1,0.2],
[7.7, 3.6, 3.6, 7.7, 2.7, 0.1,93.1,0.1],
[2.4, 0.6, 0.6, 7.7, 3.6, 4.5,0.2,95.1]
])
labels = ["Right set", "Right spike", "Right pass", "Right winpoint", "Left winpoint", "Left pass", "Left spike", "Left set"]
# 每个矩阵的维度
n_labels = len(labels)
cell_size = 1.5 # 每个小块的大小
# 创建图形
fig, ax = plt.subplots(1, 2, figsize=(n_labels * cell_size * 2, n_labels * cell_size))
# 绘制第一个混淆矩阵
sns.heatmap(intra_inter, annot=True, fmt=".1f", cmap="Blues", ax=ax[0], xticklabels=labels, yticklabels=labels, cbar=False)
ax[0].set_title("Intra+Inter")
ax[0].set_xlabel('')
ax[0].set_ylabel('')
# 绘制第二个混淆矩阵
sns.heatmap(intra_inter_obj, annot=True, fmt=".1f", cmap="Blues", ax=ax[1], xticklabels=labels, yticklabels=labels, cbar=True)
ax[1].set_title("Intra+Inter+Object")
ax[1].set_xlabel('')
ax[1].set_ylabel('')
# 调整底部标签的倾斜角度
plt.setp(ax[0].get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")
plt.setp(ax[1].get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")
# 获取颜色条对象
cbar = ax[1].collections[0].colorbar
# 调整颜色条的位置和大小
cbar.ax.set_position([0.92, ax[1].get_position().y0, 0.02, ax[1].get_position().height])
# 调整布局
plt.tight_layout(rect=[0, 0, 0.9, 1])
plt.show()
效果图如下: