【Python绘图】两种绘制混淆矩阵的方式 (ConfusionMatrixDisplay(), imshow()) 以及两种好看的colorbar

在机器学习领域,混淆矩阵是一个评估分类模型性能的重要工具。它不仅展示了模型预测的准确性,还揭示了模型在不同类别上的表现。本文介绍两种在Python中绘制混淆矩阵的方法:ConfusionMatrixDisplay()imshow(),以及两种好看的colorbar:coolwarm_rGnBu 以增强可视化效果。



ConfusionMatrixDisplay()

ConfusionMatrixDisplay() 是一个来自 scikit-learn 库的类,用于可视化混淆矩阵。

sklearn.metrics.ConfusionMatrixDisplay 的官方社区描述:

基本用法:

ConfusionMatrixDisplay 可以通过以下方式创建:

from sklearn.metrics import ConfusionMatrixDisplay

# 假设 cm 是一个混淆矩阵
disp = ConfusionMatrixDisplay(confusion_matrix=cm)
disp.plot()

参数和方法:

  • confusion_matrix: 参数,一个形状为 (n_classes, n_classes) 的 ndarray,表示混淆矩阵。
  • display_labels: 参数,一个形状为 (n_classes,) 的 ndarray,默认为 None。用于设置绘图时的标签。如果为 None,则显示标签从 0 到 n_classes - 1。
  • plot(): 方法,绘制混淆矩阵的可视化。

示例:

在这里插入图片描述

在这里插入图片描述

示例代码:

from sklearn.metrics import ConfusionMatrixDisplay
import os
import matplotlib.pyplot as plt

import numpy as np
import numpy.random as npr
npr.seed(0)

# Save path
save_path = './plot'
os.makedirs(save_path, exist_ok=True)

# Generate random data 0~1
n = 10
data = npr.rand(n, n) * 0.8
for i in range(n):
    data[i, i] = 1.0

# Plot confusion matrix
fig, ax = plt.subplots(figsize=(8, 8))

cm = ConfusionMatrixDisplay(data, display_labels=np.arange(n))
cm.plot(ax=ax, cmap="GnBu", include_values=False, xticks_rotation=90)  # GnBu, coolwarm_r

ax.set_xlabel('Trials', fontsize=20)
ax.set_ylabel('Trials', fontsize=20)

plt.title(f'Confusion matrix', fontsize=30)
plt.tight_layout()

plt.savefig(f'{save_path}/confu_mat_1-2.png', dpi=300)
plt.show()


imshow()

imshow() 是一个来自 Matplotlib 库的函数,用于在图形用户界面(GUI)中显示图像。这个函数可以处理多种类型的图像数据,包括灰度图和彩色图,是 Matplotlib 中用于图像显示的基础函数之一。

matplotlib.pyplot.imshow 的官方描述:https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.imshow.html

基本用法:

import matplotlib.pyplot as plt
import numpy as np

# 创建一个随机数组作为图像数据
image_data = np.random.rand(10, 10)

# 使用 imshow() 显示图像
plt.imshow(image_data)
plt.colorbar()  # 显示颜色条
plt.show()

参数:

imshow() 函数接受多个参数,以下是一些常用的参数:

  • X: 图像数据,可以是 2D 数组(灰度图)或 3D 数组(彩色图)。
  • cmap: 颜色映射表,用于定义颜色。例如,cmap=‘gray’ 表示灰度图,cmap=‘viridis’ 是一种常用的彩色映射。
  • norm: 归一化对象,用于调整数据值到 [0, 1] 范围。
  • aspect: 图像的纵横比,可以是 ‘auto’、‘equal’ 或一个数值。
  • interpolation: 插值方法,用于定义图像的缩放方式,如 ‘nearest’、‘bilinear’、‘bicubic’ 等。
  • alpha: 图像的透明度。

imshow() 返回一个 AxesImage 对象,这个对象包含了图像的显示信息,可以用来进一步定制图像的显示效果。

示例:

在这里插入图片描述

在这里插入图片描述

  • ConfusionMatrixDisplay()内置函数定义了所绘制的混淆矩阵必须为方针,而imshow()可以绘制行列数不等的矩形:

在这里插入图片描述

在这里插入图片描述

示例代码:

from mpl_toolkits.axes_grid1 import make_axes_locatable

import os
import matplotlib.pyplot as plt

import numpy as np
import numpy.random as npr
npr.seed(0)

# Save path
save_path = './plot'
os.makedirs(save_path, exist_ok=True)

# Generate random data 0~1
m = 6
n = 10
data = npr.rand(m, n) * 0.8
if m == n:
    for i in range(n):
        data[i, i] = 1.0

fig, ax = plt.subplots(figsize=(n, m))
cm = ax.imshow(data, cmap='coolwarm_r', interpolation="nearest", vmin=0.0, vmax=1.0)  # coolwarm_r, GnBu

# # 绘制一条对角线
# ax.plot([-0.5, n + 0.5], [-0.5, n + 0.5], color='black', alpha=0.2)

ax.set_xticks(np.arange(n))
ax.set_yticks(np.arange(m))

ax.set_xticklabels(np.arange(n), fontsize=15, rotation=90)
ax.set_yticklabels(np.arange(m), fontsize=15)

plt.xlabel('N', fontsize=20)
plt.ylabel('M', fontsize=20)

plt.title(f'Confusion matrix', fontsize=30)

divider = make_axes_locatable(ax)
cax = divider.append_axes("right", size="4%", pad=0.2)
cb = fig.colorbar(cm, cax=cax)
cb.ax.tick_params(labelsize=15)

plt.tight_layout()

plt.savefig(f'{save_path}/confu_mat_3-1.png', dpi=300)
plt.show()


两种 colorbar

  • coolwarm_r
    在这里插入图片描述

  • GnBu
    在这里插入图片描述

更多 colorbar:https://astromsshin.github.io/science/code/matplotlib_cm/index.html
在这里插入图片描述


创作不易,麻烦点点赞和关注咯!

要使用Python绘制混淆矩阵并更改文字的字号,可以利用`sklearn`库处理混淆矩阵数据以及`matplotlib`和`seaborn`库进行可视化。以下是详细的步骤: ### 步骤一:安装必要的库 首先,确保您已经安装了 `scikit-learn`, `numpy`, 和 `matplotlib` 库。如果尚未安装,可通过以下命令安装: ```bash pip install scikit-learn matplotlib numpy seaborn ``` ### 步骤二:导入所需的库和函数 接下来,在Python脚本中导入需要的库: ```python import numpy as np from sklearn.metrics import confusion_matrix from matplotlib import pyplot as plt import seaborn as sns ``` ### 步骤三:创建或加载数据集 假设我们正在处理的是分类模型的数据。下面是一个简单的示例数据: ```python # 示例标签预测值和真实值 y_true = [2, 0, 2, 2, 0, 1] y_pred = [0, 0, 2, 2, 0, 2] # 计算混淆矩阵 cm = confusion_matrix(y_true, y_pred) print(cm) ``` ### 步骤四:绘制混淆矩阵 现在我们将绘制这个混淆矩阵,并设置文字的字号: ```python plt.figure(figsize=(8,6)) sns.heatmap(cm, annot=True, fmt="d", cmap='Blues', cbar=False) # 设置文本字体大小 for text in ax.texts: text.set_fontsize(14) # 添加标题和坐标轴标签 ax.set_title('Confusion Matrix', fontsize=16) ax.set_xlabel('Predicted label', fontsize=14) ax.set_ylabel('True label', fontsize=14) plt.show() ``` ### 相关问题: 1. **如何调整混淆矩阵的颜色地图**? - 您可以通过修改`sns.heatmap()`中的`cmap`参数来调整颜色映射。 2. **如何从CSV文件加载混淆矩阵数据**? - 使用pandas读取CSV文件,然后计算混淆矩阵。 3. **如何将混淆矩阵应用于实际的机器学习项目中**? -混淆矩阵用于评估分类模型的表现,分析模型对不同类别的识别能力。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值