摘要这一篇简单介绍一下混淆矩阵的计算和绘制,混淆矩阵可以用来判断模型预测的结果。
介绍
这一篇主要介绍一下绘制混淆矩阵(confusion matrix)的方式。通常在看model的效果的时候,我们会使用混淆矩阵来进行检测。
主要参考资料 :
具体绘制方式
混淆矩阵的计算
混淆矩阵就是我们会计算最后分类错误的个数, 如计算将class1分为class2的个数,以此类推。
我们可以使用下面的方式来进行混淆矩阵的计算。
# 绘制混淆矩阵
def confusion_matrix(preds, labels, conf_matrix):
preds = torch.argmax(preds, 1)
for p, t in zip(preds, labels):
conf_matrix[p, t] += 1
return conf_matrix
conf_matrix = torch.zeros(10, 10)
for data, target in test_loader:
output = fullModel(data.to(device))
conf_matrix = confusion_matrix(output, target, conf_matrix)
最后得到的conf_matrix就是混淆矩阵的值。
混淆矩阵的可视化
有了上面的混淆矩阵中具体的值,下面就是进行可视化的步骤。可视化我们使用seaborn来进行完成。因为我这里conf_matrix的值是tensor, 所以需要先转换为Numpy.
import seaborn as sn
df_cm = pd.DataFrame(conf_matrix.numpy(),
index = [i for i in list(Attack2Index.keys())],
columns = [i for i in list(Attack2Index.keys())])
plt.figure(figsize = (10,7))
sn.heatmap(df_cm, annot=True, cmap="BuPu")
最终的混淆矩阵的图如下所示:
混淆矩阵的可视化(进行美化)
当然, 我们还可以对混淆矩阵做更多的处理, 使得显示的时候能更加好看一些. 下面的绘制混淆矩阵的函数我是在下面的链接里看到的, 最终的效果很是不错。
这里简单贴一下代码,可以方便直接进行使用。
import itertools
# 绘制混淆矩阵
def plot_confusion_matrix(cm, classes, normalize=False, title='Confusion matrix', cmap=plt.cm.Blues):
"""
This function prints and plots the confusion matrix.
Normalization can be applied by setting `normalize=True`.
Input
- cm : 计算出的混淆矩阵的值
- classes : 混淆矩阵中每一行每一列对应的列
- normalize : True:显示百分比, False:显示个数
"""
if normalize:
cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
print("Normalized confusion matrix")
else:
print('Confusion matrix, without normalization')
print(cm)
plt.imshow(cm, interpolation='nearest', cmap=cmap)
plt.title(title)
plt.colorbar()
tick_marks = np.arange(len(classes))
plt.xticks(tick_marks, classes, rotation=45)
plt.yticks(tick_marks, classes)
fmt = '.2f' if normalize else 'd'
thresh = cm.max() / 2.
for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
plt.text(j, i, format(cm[i, j], fmt),
horizontalalignment="center",
color="white" if cm[i, j] > thresh else "black")
plt.tight_layout()
plt.ylabel('True label')
plt.xlabel('Predicted label')
测试数据如下所示:
cnf_matrix = np.array([[8707, 64, 731, 164, 45],
[1821, 5530, 79, 0, 28],
[266, 167, 1982, 4, 2],
[691, 0, 107, 1930, 26],
[30, 0, 111, 17, 42]])
attack_types = ['Normal', 'DoS', 'Probe', 'R2L', 'U2R']
我们分别测试normalize=True/False的效果。
plot_confusion_matrix(cnf_matrix, classes=attack_types, normalize=True, title='Normalized confusion matrix')
plot_confusion_matrix(cnf_matrix, classes=attack_types, normalize=False, title='Normalized confusion matrix')