sklearn中混淆矩阵(confusion_matrix函数)的理解与使用

本文详细介绍了混淆矩阵的概念,它是机器学习中用于评估分类模型性能的重要工具。混淆矩阵通过比较预测值与实际值,提供了直观的分类精度视图,包括总体精度、制图精度和用户精度等。文章还解释了如何使用sklearn库中的混淆矩阵函数,并提供了具体的代码示例。
混淆矩阵

百度百科的定义:

混淆矩阵(confusion
matrix)也称误差矩阵,是表示精度评价的一种标准格式,用n行n列的矩阵形式来表示。具体评价指标有总体精度、制图精度、用户精度等,这些精度指标从不同的侧面反映了图像分类的精度。
在人工智能中,混淆矩阵(confusion
matrix)是可视化工具,特别用于监督学习,在无监督学习一般叫做匹配矩阵。在图像精度评价中,主要用于比较分类结果和实际测得值,可以把分类结果的精度显示在一个混淆矩阵里面。混淆矩阵是通过将每个实测像元的位置和分类与分类图像中的相应位置和分类相比较计算的。

在机器学习领域,混淆矩阵(confusion matrix),又称为可能性表格或是错误矩阵。
它是一种特定的矩阵用来呈现算法性能的可视化效果,通常是监督学习(非监督学习,通常用匹配矩阵:matching matrix)。
其每一列代表预测值,每一行代表的是实际的类别。
这个名字来源于它可以非常容易的表明多个类别是否有混淆(也就是一个class被预测成另一个class)。

简单的图解:(这张图真的非常好理解)
在这里插入图片描述

使用

官方文档中给出的用法是:
sklearn.metrics.confusion_matrix(y_true, y_pred, labels=None, sample_weight=None)

y_true: 是样本真实分类结果
y_pred: 是样本预测分类结果
labels:是所给出的类别,通过这个可对类别进行选择
sample_weight : 样本权重

实现例子:

from sklearn.metrics import confusion_matrix

y_true=[2,1,0,1,2,0]
y_pred=[2,0,0,1,2,1]

C=confusion_matrix(y_true, y_pred)

结果:

array([[1, 1, 0],
       [1, 1, 0],
       [1, 0, 2]])

下面是官方文档上的一个例子

y_true = ["cat", "ant", "cat", "cat", "ant", "bird"]
y_pred = ["ant", "ant", "cat", "cat", "ant", "cat"]
confusion_matrix(y_true, y_pred, labels=["ant", "bird", "cat"])

运行结果

array([[2, 0, 0],
       [0, 0, 1],
       [1, 0, 2]])
### 如何使用 `sklearn` 输出模型在测试集上的混淆矩阵 为了计算并展示分类模型在测试集上的混淆矩阵,可以利用 scikit-learn 库中的 `confusion_matrix` 函数[^2]。此过程涉及几个关键步骤的操作。 下面是一个完整的 Python 示例代码,展示了如何加载必要的库、训练一个简单的分类器以及输出该分类器针对给定测试数据的混淆矩阵: ```python from sklearn.datasets import load_iris from sklearn.model_selection import train_test_split from sklearn.metrics import confusion_matrix import matplotlib.pyplot as plt import seaborn as sns import numpy as np # 加载示例数据集 data = load_iris() X, y = data.data, data.target # 划分训练集和测试集 X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42) # 训练简单分类器 (这里以K近邻为例) from sklearn.neighbors import KNeighborsClassifier clf = KNeighborsClassifier(n_neighbors=3).fit(X_train, y_train) # 预测测试集标签 y_pred = clf.predict(X_test) # 使用 confusion_matrix 函数获取混淆矩阵 cm = confusion_matrix(y_test, y_pred) print("Confusion Matrix:") print(cm) # 可视化混淆矩阵 plt.figure(figsize=(8,6)) sns.heatmap(cm, annot=True, fmt='d', cmap="Blues", xticklabels=data.target_names, yticklabels=data.target_names) plt.ylabel('True label') plt.xlabel('Predicted label') plt.title('Confusion Matrix Heatmap') plt.show() ``` 上述脚本不仅实现了基本的功能需求——即通过调用 `confusion_matrix()` 来获得预测结果实际类别之间的对比情况;还进一步借助 Matplotlib 和 Seaborn 工具包来绘制热力图形式的可视化表示,使得理解不同类别的误判状况更加直观[^1]。
评论 16
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值