import itertools
import numpy as np
import matplotlib.pyplot as plt
def plot_confusion_matrix(cm, classes, normalize=False, title='Confusion matrix', cmap=plt.cm.Blues):
    """Print a confusion matrix and draw it as an annotated heat map.

    Rows are plotted as true labels and columns as predicted labels, and
    every cell is labelled with its value. The figure is displayed with
    ``plt.show()`` at the end.

    Args:
        cm: Square confusion matrix; anything ``np.asarray`` accepts (a NumPy
            array, nested lists, a CPU tensor). ``cm[i, j]`` is drawn at
            row ``i``, column ``j``.
        classes: Tick labels for both axes; its length should equal the
            matrix size.
        normalize: If True, divide each row by its sum so each cell shows the
            fraction of that true class. Rows summing to zero are shown as 0
            rather than NaN.
        title: Figure title.
        cmap: Matplotlib colormap passed to ``imshow``.
    """
    cm = np.asarray(cm)
    if normalize:
        # Row sums kept as an (n, 1) column so they broadcast across each row.
        row_sums = cm.sum(axis=1, keepdims=True)
        # Skip rows that sum to zero instead of producing NaN plus a
        # RuntimeWarning; those rows stay 0 (matching what sklearn's
        # confusion_matrix(normalize='true') returns).
        cm = np.divide(cm, row_sums,
                       out=np.zeros(cm.shape, dtype=float),
                       where=row_sums != 0)
        print("Normalized confusion matrix")
    else:
        print('Confusion matrix, without normalization')
    print(cm)

    # Heat map of the matrix; cmap selects the color scheme.
    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    # Color bar on the right, mapping colors to values.
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks(tick_marks, classes)

    # The 'd' format code only accepts integers; normalized or float-typed
    # input (e.g. counts stored as float) would raise ValueError in format(),
    # so use '.2f' for anything that is not an integer dtype.
    fmt = 'd' if not normalize and np.issubdtype(cm.dtype, np.integer) else '.2f'
    # Cells above half the maximum get white text, intended for readability on
    # the dark end of colormaps such as Blues. An empty matrix has no max.
    thresh = cm.max() / 2. if cm.size else 0.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        # Annotate cell cm[i, j]; plt.text takes (x, y) = (column j, row i).
        plt.text(j, i, format(cm[i, j], fmt), horizontalalignment="center",
                 color="white" if cm[i, j] > thresh else "black")
    plt.tight_layout()
    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    plt.show()
使用:
# Example usage of plot_confusion_matrix.
# NOTE(review): `cm` is not defined in this snippet; it is assumed to be the
# 10x10 confusion matrix computed elsewhere (the one printed below).
# `plotcm` is presumably the module this function was saved in; it must be
# imported (e.g. `import plotcm`) before running this. TODO confirm.
print(cm)
# Tick labels for both axes. Presumably in label-index order (Fashion-MNIST
# labels 0-9); they must match the row/column order of cm.
names = (
'T-shirt/top',
'Trouser',
'Pullover',
'Dress',
'Coat',
'Sandal',
'Shirt',
'Sneaker',
'Bag',
'Ankle boot')
# Create the figure first so the plot draws into a 10x10-inch canvas, giving
# room for the 10x10 grid and the rotated tick labels.
plt.figure(figsize=(10, 10))
plotcm.plot_confusion_matrix(cm, names)
cm矩阵:
[[5624 13 68 99 10 3 121 2 60 0]
[ 14 5901 4 61 4 1 8 0 7 0]
[ 151 6 4721 60 712 1 314 0 35 0]
[ 295 42 22 5382 154 0 88 1 16 0]
[ 20 9 444 230 4892 0 378 1 26 0]
[ 0 1 1 1 0 5763 0 175 16 43]
[1559 18 579 112 466 1 3192 1 72 0]
[ 0 0 0 1 0 21 0 5850 4 124]
[ 19 4 17 26 15 2 37 5 5875 0]
[ 0 0 2 1 0 23 0 208 8 5758]]
绘制效果:
学习自deeplizard.com
数据集来自 Zalando 的 Fashion-MNIST
上面的混淆矩阵是一个简单 CNN 在训练集上的结果（每类 6000 张图片，共 60000 张）