from sklearn.metrics import confusion_matrix
import seaborn as sns
def plot_cm(labels, predictions):
cm = confusion_matrix(labels, predictions)
plt.figure(figsize=(5,5))
sns.heatmap(cm, annot=True, fmt="d")
plt.title('Confusion matrix @p')
plt.ylabel('Actual label')
plt.xlabel('Predicted label')
test_predictions = test_model.predict(test_features, batch_size=16)
#print(test_predictions)
plot_cm(test_labels,test_predictions)
在生产一个多分类的混淆矩阵时会出现报错:
ValueError: Classification metrics can't handle a mix of multilabel-indicator and continuous-multioutput targets
报错位置是最后一句调用混淆矩阵绘制函数时产生的,test_labels 划分类别时采用的是one-hot编码形式,如下:
[[1. 0. 0. 0. 0.] [1. 0. 0. 0. 0.] [1. 0. 0. 0. 0.] [1. 0. 0. 0. 0.] [1. 0. 0. 0. 0.] [1. 0. 0. 0. 0.] [1. 0. 0. 0. 0.] [1. 0. 0. 0. 0.] [1. 0. 0. 0. 0.] [0. 1