from sklearn.metrics import accuracy_score, confusion_matrix, f1_score
def plotCM(classes, matrix, savname):
    """Render a confusion matrix as an annotated heat map and save it.

    Args:
        classes: Sequence of class names. Entry ``k`` labels tick ``k`` on
            both the x axis (columns) and the y axis (rows).
        matrix: 2-D array-like exposing ``.shape`` and ``matrix[i, j]``
            indexing (for example a NumPy array).
        savname: Destination passed to ``Figure.savefig``.

    Notes:
        The matrix is drawn exactly as given; no row normalization is
        applied. Every cell is annotated with ``matrix[i, j] * 100``
        formatted to two decimals, and each cell value is also printed to
        stdout.
        NOTE(review): the ``* 100`` scaling presumably expects fractional
        (row-normalized) input -- confirm for callers passing raw counts.
    """
    # Select the Agg backend before creating the figure.
    plt.switch_backend('agg')
    fig = plt.figure()
    try:
        ax = fig.add_subplot(111)
        cax = ax.matshow(matrix)
        fig.colorbar(cax)

        n_rows, n_cols = matrix.shape[0], matrix.shape[1]
        for i in range(n_rows):
            for j in range(n_cols):
                print(matrix[i, j])
                # matshow draws matrix[i, j] at x = column j, y = row i, so
                # the annotation must be placed at (j, i), not (i, j).
                ax.text(j, i, '%.2f' % (matrix[i, j] * 100),
                        va='center', ha='center')

        # Pin one tick per class explicitly so that labels cannot drift when
        # an automatic locator decides to emit an extra leading tick.
        labels = list(classes)
        positions = list(range(len(labels)))
        ax.set_xticks(positions)
        ax.set_yticks(positions)
        ax.set_xticklabels(labels, rotation=90)
        ax.set_yticklabels(labels)

        fig.savefig(savname)
    finally:
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)
# Compare the reference labels with the predicted ones, then save the
# resulting two-class confusion matrix as a figure named 'res'.
matrix = confusion_matrix(y_true=res['label'], y_pred=res['pre_label'])
plotCM(classes=['0', '1'], matrix=matrix, savname='res')