文章转载于知乎saya:https://zhuanlan.zhihu.com/p/35261741
可视化混淆矩阵
混淆矩阵是我们用来理解分类模型性能的表格。 这有助于我们理解如何将测试数据分类到不同的类中。 当我们想微调我们的算法时,我们需要了解在做出这些更改之前数据是如何被错误分类的。 有些种类比其他课程更糟糕,混淆矩阵将帮助我们理解这一点。 我们来看看下图:
在前面的图表中,我们可以看到我们如何将数据分类到不同的类中。 理想情况下,我们希望所有非对角线元素都为0.这表明完美的分类!让我们考虑
class 0
。总体而言,52个项目实际上属于
class 0
。如果我们总结第一行中的数字,则得到52。 现在,这些项目中有45项被正确预测,但是分类器说其中4项属于
class 1
,3项属于
class 2
。我们可以对其余两行应用相同的分析。值得注意的是,来自
class 1
的11个项被错误分类为
class 0
。这构成了该类中约16%的数据点。 这是我们可以用来优化模型的见解。
-
导入必要的数据库
import numpy as npimport matplotlib.pyplot as pltfrom sklearn.metrics import confusion_matrix
-
生成数据调用confusion_matrix模块
y_true = [1, 0, 0, 2, 1, 0, 3, 3, 3]y_pred = [1, 1, 0, 2, 1, 0, 1, 3, 3]confusion_mat = confusion_matrix(y_true, y_pred)
-
定义显示的结
# Show confusion matrixdef plot_confusion_matrix(confusion_mat): plt.imshow(confusion_mat, interpolation='nearest', cmap=plt.cm.gray) plt.title('Confusion matrix') plt.colorbar() tick_marks = np.arange(4) plt.xticks(tick_marks, tick_marks) plt.yticks(tick_marks, tick_marks) plt.ylabel('True label') plt.xlabel('Predicted label') plt.show()
我们使用imshow函数来绘制混淆矩阵。 其他功能都很简单! 我们只需使用相关功能设置标题,颜色条,标记和标签。 tick_marks参数的范围从0到3,因为我们在数据集中有四个不同的标签。 np.arangefunction给了我们这个numpy数组。
-
进行显示结果
plot_confusion_matrix(confusion_mat)
输出结果:
对角线的颜色很强烈,我们希望它们的颜色变得深。 浅黄色表示零。 非对角线空间中有几个绿色,表示错误分类。 例如,当真实标签为0时,预测标签为1,如我们在第一行中所看到的。 事实上,所有的错误分类属于第一类,因为第二列包含三个非零的行。 从图中很容易看到这一点。
-
提取性能报告
# Print classification reportfrom sklearn.metrics import classification_reporttarget_names = ['Class-0', 'Class-1', 'Class-2', 'Class-3']print (classification_report(y_true, y_pred, target_names=target_names))
输出的结果:
precision recall f1-score support Class-0 1.00 0.67 0.80 3 Class-1 0.50 1.00 0.67 2 Class-2 1.00 1.00 1.00 1 Class-3 1.00 0.67 0.80 3avg / total 0.89 0.78 0.79 9
结果分析