【Sklearn-Bug驯化-混淆矩阵】成功Sklearn中plot_confusion_matrix出现ImportError: cannot import name ‘plot_confusion_matrix’ from ‘sklearn.metrics’
本次修炼方法请往下查看
🌈 欢迎莅临我的个人主页 👈这里是我工作、学习、实践 IT领域、真诚分享 踩坑集合,智慧小天地!
🎇 免费获取相关内容文档关注:微信公众号,发送 pandas 即可获取
🎇 相关内容视频讲解 B站
🎓 博主简介:AI算法驯化师,混迹多个大厂搜索、推荐、广告、数据分析、数据挖掘岗位 个人申请专利40+,熟练掌握机器、深度学习等各类应用算法原理和项目实战经验。
🔧 技术专长: 在机器学习、搜索、广告、推荐、CV、NLP、多模态、数据分析等算法相关领域有丰富的项目实战经验。已累计为求职、科研、学习等需求提供近千次有偿|无偿定制化服务,助力多位小伙伴在学习、求职、工作上少走弯路、提高效率,近一年好评率100% 。
📝 博客风采: 积极分享关于机器学习、深度学习、数据分析、NLP、PyTorch、Python、Linux、工作、项目总结相关的实用内容。
下滑查看解决方法
🎯 1. 问题描述
在使用scikit-learn-1.5.1
版本来计算分类模型的混淆矩阵时,通过下面的代码如下如下的问题,具体的代码为:
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import plot_confusion_matrix
# 假设有一个实际标签和预测标签
y_true = np.array([0, 1, 2, 0, 1, 2, 0, 1, 2])
y_pred = np.array([0, 0, 1, 0, 1, 1, 0, 1, 2])
# 使用plot_confusion_matrix绘制混淆矩阵
class_names = ['class 0', 'class 1', 'class 2']
confusion_matrix = plot_confusion_matrix(y_true, y_pred, display_labels=class_names, normalize='true')
# 设置图表标题
confusion_matrix.ax_.set_title('Normalized Confusion Matrix')
# 显示图表
plt.show()
上述的代码会报如下的错误,具体为:
💡 2. 解决方法
2.1 降低版本
对于上述的问题是由于sklearn高版本抛弃了plot_confusion_matrix函数的使用,如果方便降低版本的话,可以通过如下的命令将sklearn的版本降低就可以了,具体的命令如下所示:
pip install scikit_learn==0.24.1
2.2 替换函数ConfusionMatrixDisplay
在sklearn的版本更新之后,可以使用最新的函数进行平替,具体的写法如下所示:
from sklearn.metrics import ConfusionMatrixDisplay
# 假设有一个实际标签和预测标签
# 使用plot_confusion_matrix绘制混淆矩阵
class_names = ['class 0', 'class 1', 'class 2']
y_true = np.array([0, 1, 2, 0, 1, 2, 0, 1, 2])
y_pred = np.array([0, 0, 1, 0, 1, 1, 0, 1, 2])
ConfusionMatrixDisplay.from_predictions(y_true, y_pred,display_labels=class_names, normalize='true')
具体的结果如下所示:
💡 3. 参数介绍
ConfusionMatrixDisplay常用的参数如下所示:
cmap:
混淆矩阵的颜色映射。可以是一个字符串名称(如"Blues"、"Greens"等),也可以是一个matplotlib.colors.Colormap对象。ax:
绘制混淆矩阵的坐标轴。如果不提供该参数,默认使用当前活动的坐标轴。include_values:
指示是否在混淆矩阵图中显示每个单元格的数值。默认为True。normalize:
指示是否在混淆矩阵中显示比例而不是原始计数。默认为False。xticks_rotation:
标签的旋转角度。默认为None,表示不旋转。format:
数值的格式化字符串。默认为".2f",表示保留两位小数。display_labels:
类别标签,是一个一维数组,包含了每个类别的名称。如果不提供该参数,默认使用数字0到类别数量-1作为类别标签。