# 论文中神经网络混淆矩阵图的绘制 (Plotting a neural-network confusion matrix figure for a paper)
# 导入必要的库 (Import the required libraries)
import matplotlib.pyplot as plt
import seaborn as sns
import torch
from PIL import Image
from sklearn.metrics import confusion_matrix
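# 准备测试数据集，构建模型并加载训练好的权重
# (Prepare the test dataset, build the model and load its trained weights;
#  the code below expects `net`, `device` and `test_transforms_loader` to be defined here.)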
# 计算混淆矩阵，绘制热力图并保存 (Compute the confusion matrix, plot it as a heatmap and save the figure)
# Run the whole test set through the network once and collect, per sample,
# the ground-truth class index and the predicted class index as flat lists
# of Python ints (the form sklearn's confusion_matrix consumes below).
net.eval()  # evaluation mode for layers such as dropout/batch-norm; harmless if already set
labels = []   # ground-truth class index, one entry per test sample
outputs = []  # predicted class index, one entry per test sample
with torch.no_grad():  # inference only: no autograd graph is needed
    for img, label in test_transforms_loader:
        img = img.to(device)
        # Predicted class = index of the largest score along dim 1
        # (assumes net returns (batch, num_classes) scores -- confirm).
        predict_img = net(img).argmax(dim=1)
        # .tolist() turns each batch into plain ints (CPU or GPU tensor alike),
        # instead of storing 0-d tensors / numpy scalars and flattening later.
        # Assumes `label` is a 1-D tensor of class indices -- confirm.
        labels.extend(label.tolist())
        outputs.extend(predict_img.cpu().tolist())
# Build the confusion matrix: rows = true class, columns = predicted class
# (sklearn's convention, matching the axis labels below). The result is bound
# to `cm`, NOT `confusion_matrix`, so the sklearn function imported at the top
# is not shadowed -- rebinding it to an ndarray breaks every later call.
cm = confusion_matrix(labels, outputs).astype('int')  # integer dtype required by fmt='d'
save_path = 'confusion_matrix.png'  # output file for the figure

plt.figure(figsize=(10, 8))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')  # write each cell's count into the cell
plt.xlabel('Predicted labels')
plt.ylabel('True labels')
plt.title('Confusion Matrix')
# Save BEFORE plt.show(): with non-interactive backends show() typically
# closes the figure, so a savefig() placed after it would write an empty image.
plt.savefig(save_path, dpi=300, bbox_inches='tight')
plt.show()