使用scikit-learn中的metrics.plot_confusion_matrix混淆矩阵函数分析分类器的误差来源

在前面的文章中介绍了使用scikit-learn绘制ROC曲线使用scikit-learn绘制误差学习曲线,通过绘制ROC曲线和误差学习曲线可以让我们知道我们的模型现在整体上做的有多好,可以判断模型的状态是过拟合还是欠拟合,从而确定后续的优化方向。但是绘制学习曲线的方法只能让我们从整体上了解模型的性能,并不能具体展示具体的误差来源。在吴恩达老师的视频中,多次强调误差分析的重要性,就是针对模型处理出错的样本进行重点研究分析,然后选择可能的优化方向。今天这篇短文就来讲一下针对分类问题,如何使用scikit-learn工具进行简单的误差分析。

本示例使用SVM分类器,对手写数字进行分类。

1、加载数据集,并划分训练集和验证集

%matplotlib inline
from sklearn import datasets
from sklearn.svm import SVC
import warnings
warnings.filterwarnings("ignore")

from sklearn.model_selection import train_test_split

digits = datasets.load_digits()
X, y = digits.data, digits.target
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=100)

X_train.shape, X_test.shape

((1257, 64), (540, 64))

2、在训练集上训练一个SVM分类器

svc_clf = SVC(kernel="linear", C=1.0)
svc_clf.fit(X_train, y_train)

3、绘制分类器在测试集上的结果混淆矩阵

from sklearn import metrics

metrics.plot_confusion_matrix(svc_clf, X_test, y_test)

混淆矩阵的横坐标表示模型的预测结果,纵坐标表示真实结果。淆矩阵的每行元素表示类别i(行编号)被预测为类别j(列编号)的数量。所以在上图中,除对角线以外的非0数值都是被误分类的样本数量。比如第9行的第6列的数值为2,表示有2个数字8倍误分类成了数字5。通过混淆矩阵可以方便的看出模型在矩阵那些类别之间的误分类比较严重,从而有利于确定下一步的优化方向。

在上图中,除对角线以外的非0元素总和是:1 + 1 + 1 + 1 + 1 + 1 + 2 = 8,那么该分类模型在测试集上的准确率应该是(540 - 8)/540 = 0.985。使用metrics.accuracy函数计算一下看看是否一致:

metrics.accuracy_score(y_test, svc_clf.predict(X_test))

0.9851851851851852。结果与预期一致。

 

 

参考:scikit-learn中文翻译

参考:https://en.wikipedia.org/wiki/Confusion_matrix

  • 6
    点赞
  • 41
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值