【机器学习】使用scikitLearn可视化混淆矩阵进行分类误差分析

36 篇文章 5 订阅
27 篇文章 3 订阅

分类相关导航:
【机器学习】分类任务以mnist为例,数据集准备及预处理
【机器学习】scikitLearn分类任务以mnist为例,训练二分类器并衡量性能指标:ROC及PR曲线
【机器学习】scikitLearn分类:鉴别二分类、多分类、多标签及多输出的分类任务

1.绘制混淆矩阵:

#1.建立模型
from sklearn.linear_model import SGDClassifier
sgd_clf = SGDClassifier(max_iter=1000, tol=1e-3, random_state=42)
sgd_clf.fit(X_train, y_train_5)
#2.绘制混淆矩阵
y_train_pred = cross_val_predict(sgd_clf, X_train_scaled, y_train, cv=3)
conf_mx = confusion_matrix(y_train, y_train_pred)
#3.给定绘图用函数:
def plot_confusion_matrix(matrix):
    fig = plt.figure(figsize=(8,8))
    ax = fig.add_subplot(111)
    #matplotlib.pyplot.matshow()函数用于在新图形窗口中将数组表示为矩阵
    cax = ax.matshow(matrix)
    fig.colorbar(cax)
#4.调用函数进行绘图
plt.matshow(conf_mx, cmap=plt.cm.gray)
plt.show()

在这里插入图片描述
图中颜色越白的区域,图片越多。对角线上是被正确分类的图片。

2.更改矩阵绘制方式,将注意力焦点集中在错误分类的值:

#求出每行的和
row_sums = conf_mx.sum(axis=1, keepdims=True)
#算出每行每格所占的比例,行代表真实值,这个比例代表真值对应的不同预测值
norm_conf_mx = conf_mx / row_sums
#将矩阵对角线,理论上最多的值全部置为0,增加其余位置色块的对比度
np.fill_diagonal(norm_conf_mx, 0)
#使用matshow进行矩阵展示
plt.matshow(norm_conf_mx, cmap=plt.cm.gray)
plt.show()

在这里插入图片描述
如图所示,越亮的色块错误越多,如真值5被错分为真值8的情况最多,可通过对图像进行预处理,或对5及8样本数增广的形式,提高对应分类的正确率。

  • 1
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

颢师傅

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值