对于二分类问题,一种最全面的表示方法是使用混淆矩阵(confusion matrix),我们利用混淆矩阵来检查前面刚刚LogisticRegression的预测结果,前面我们已经将预测结果保存在了pred_logreg中,这里不再重复代码:
from sklearn.metrics import confusion_matrix
confusion = confusion_matrix(y_test, pred_logreg)
print("Confusion matrix:{}".format(confusion))
运行后结果为:
Confusion matrix:
[[401 2]
[ 8 39]]
有上述运行结果可见,得到的是一个2*2数组,其中行对应于真实的类别,列对应于预测的类别。数组中每个元素给出属于该行的类别(这里是“非9”和“9”)的样本被分类到这列对应类别中的数量,如下图进行说明:
mglearn.plots.plot_confusion_matrix_illustration()
“9与其他”分类 任务的混淆矩阵
混淆矩阵主对角线(对于一个矩阵A来说,主对角线为A[i, i])上的元素对应于正确的分类,而其他元素则告诉我们一个类别中有多少个样本被错误的划分到其他类别中。
如果我们将“9"作为正类,那么就可以将混淆矩阵的元素与前面介绍的假正例(false positive)和假反例(false negative)两个俗语联系起来。
下面我们将真证例,真反例,假正例和假反例分别简写为TN、TP、FN、FP,然后就可以得到如下图的混淆矩阵解释:
二分类混淆矩阵
下面我们用混淆矩阵来比较前面拟合过的模型(两个虚拟模型、决策树和Logistic回归):
print("Most frequency class:")
print(confusion_matrix(y_test, pred_most_frequent))
print("Dummy model:")
print(confusion_matrix(y_test, pred_dummy))
print("Decision tree:")
print(confusion_matrix(y_test, pred_tree))
print("logistic regression")
print(confusion_matrix(y_test, pred_logreg))
打印出来的结果为:
Most frequency class:
[[403 0]
[ 47 0]]
Dummy model:
[[361 42]
[ 40 7]]
Decision tree:
[[390 13]
[ 24 23]]
logistic regression
[[401 2]
[ 8 39]]
观察混淆矩阵,很明显可以看出pred_most_frequent有问题,因为它总是预测同一个类别,其他值为0。其次是pred_dumy,真正正例很少(7个),假正例个数比真正例还多。
决策树的预测比虚拟预测更有意义,假正例和真正例大体相同;当然,效果最好的还是Logistic回归了,不管是真正例还是真反例都具有明显的优势,说明预测效果好。
不过我们不得不指出的是,检查整个混淆矩阵还是比较麻烦的,我们只能通过人工的方式去检查对应的效果,下节我们将讨论通过精度、准确率、召回率等方式来总结混淆矩阵中包含的信息。