代码如下:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.ticker import MultipleLocator
# Global matplotlib settings:
#  - 'font.sans-serif' = ['SimHei'] so Chinese label text can be displayed;
#  - 'axes.unicode_minus' = False so minus signs are drawn with the ASCII
#    hyphen-minus instead of the Unicode minus sign (presumably because the
#    SimHei font lacks that glyph — NOTE(review): confirm).
plt.rcParams.update({
    'font.sans-serif': ['SimHei'],
    'axes.unicode_minus': False,
})
def plotCM(classes, matrix, savname):
    """Plot a row-normalized confusion matrix and save it as an image.

    Each row of ``matrix`` is divided by its own sum, so cell ``[i, j]``
    becomes the fraction of row ``i`` that falls in column ``j``. Diagonal
    cells are annotated with that fraction as a percentage (two decimals).

    Args:
        classes: Sequence of class names (list, tuple, ...), one per
            row/column of ``matrix``, used as tick labels on both axes.
        matrix: 2-D array-like of counts. Assumed square with rows as true
            classes and columns as predictions — NOTE(review): the function
            itself only normalizes rows; confirm the caller's convention.
        savname: Output path handed to ``Figure.savefig``.

    Rows whose sum is zero are left as all zeros instead of becoming NaN.
    The input ``matrix`` is never modified.
    """
    # Float copy of the counts. ``np.float`` no longer exists in NumPy >= 1.24,
    # so use the builtin ``float`` (float64).
    matrix = np.asarray(matrix, dtype=float)

    # Row normalization via broadcasting. ``where`` skips zero-sum rows, which
    # would otherwise divide by zero; those rows keep the zeros from ``out``.
    row_sums = matrix.sum(axis=1, keepdims=True)
    matrix = np.divide(matrix, row_sums,
                       out=np.zeros_like(matrix), where=row_sums != 0)

    # Non-interactive backend so the function works without a display.
    plt.switch_backend('agg')
    fig = plt.figure()
    ax = fig.add_subplot(111)
    cax = ax.matshow(matrix)
    fig.colorbar(cax)

    # Annotate the diagonal cells; min() keeps this in range for non-square input.
    for i in range(min(matrix.shape)):
        ax.text(i, i, '%.2f' % (matrix[i, i] * 100), va='center', ha='center')

    # One tick per cell with explicit labels. This replaces the old trick of
    # MultipleLocator(1) plus a leading '' label for its extra tick at -1.
    labels = list(classes)
    ax.set_xticks(range(matrix.shape[1]))
    ax.set_xticklabels(labels, rotation=90)
    ax.set_yticks(range(matrix.shape[0]))
    ax.set_yticklabels(labels)

    # Save this specific figure, then release it so repeated calls don't
    # accumulate open figures.
    fig.savefig(savname)
    plt.close(fig)
if __name__ == '__main__':
    # Demo run: plot a hand-written 8-class matrix of raw counts and write
    # the figure to "test.png" (relative to the current working directory).
    class_names = ["A", "B", "C", "D", "E", "F", "G", "H"]
    counts = [
        [23, 1, 2, 52, 5, 0, 1, 0],
        [0, 107, 2, 13, 18, 0, 2, 0],
        [4, 23, 15, 6, 3, 0, 1, 0],
        [12, 73, 1, 114, 0, 0, 0, 0],
        [1, 0, 0, 0, 100, 0, 10, 0],
        [0, 5, 0, 4, 7, 0, 0, 0],
        [7, 10, 2, 31, 0, 0, 150, 0],
        [0, 0, 0, 0, 0, 0, 5, 0],
    ]
    output_path = "test.png"
    plotCM(class_names, np.array(counts), output_path)
结果图如下: