使用python matplotlib绘制混淆矩阵
今天使用了python matplotlib包,绘制混淆矩阵。基本代码参考官网教程,在此基础上增加了格网显示。
代码说明:
cm - 混淆矩阵的数值, 是一个二维numpy数组
classes - 各个类别的标签(label)
title - 图片标题
cmap - 颜色图
def plot_Matrix(cm, classes, title=None, cmap=plt.cm.Blues):
plt.rc('font',family='Times New Roman',size='8') # 设置字体样式、大小
# 按行进行归一化
cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
print("Normalized confusion matrix")
str_cm = cm.astype(np.str).tolist()
for row in str_cm:
print('\t'.join(row))
# 占比1%以下的单元格,设为0,防止在最后的颜色中体现出来
for i in range(cm.shape[0]):
for j in range(cm.shape[1]):
if int(cm[i, j]*100 + 0.5) == 0:
cm[i, j]=0
fig, ax = plt.subplots()
im = ax.im