分类模型confusion matrix混淆矩阵可视化

        之前写过一篇关于在scikit-learn工具包中,可视化estimator分类模型分类结果的confusion matrix混淆矩阵可视化的方法,具体可以参考看这里,看这里。今天这篇介绍一下如何使用scikit-learn工具中提供的相关方法,可视化其他任意框架(比如深度学习框架)的分类模型预测结果的混淆矩阵。

        下面先说一下几个关键步骤:

1、确定类别列表,类别列表和one-hot的编码顺序一致,这里使用cifar-10的类别列表作为演示的例子。

classes = ["airplane", "automobile", "bird", "cat", "deer", "dog", "frog", "horse", "ship", "truck"

2、准备好样本的真实label,这里我手动构造一个1000个样本的label,每一类100个。

# 生成数据集的GT标签
gt_labels = np.zeros(1000).reshape(10, -1)
for i in range(10):
    gt_labels[i] = i
gt_labels = gt_labels.reshape(1, -1).squeeze()
print("gt_labels.shape : {}".format(gt_labels.shape))
print("gt_labels : {}".format(gt_labels[::5]))

3、准备好样本的预测label,这里我也手动构造这1000个样本的预测label,构造时才用了一点规则,构造出来的预测结果保证从第0类到第9类的预测准确率是逐渐降低的。

# 生成数据集的预测标签
pred_labels = np.zeros(1000).reshape(10, -1)
for i in range(10):
    # 标签生成规则:对于真值类别编号为i的数据,生成的预测类别编号为[0, i-1]之间的随机值
    # 这样生成的预测准确率从0到9逐渐递减
    pred_labels[i] = np.random.randint(0, i + 1, 100)
pred_labels = pred_labels.reshape(1, -1).squeeze()
print("pred_labels.shape : {}".format(pred_labels.shape))
print("pred_labels : {}".format(pred_labels[::5]))

4、计算真是label和预测label的混淆矩阵,直接调用scikit-learn中的confusion_matrix方法

# 使用sklearn工具中confusion_matrix方法计算混淆矩阵
confusion_mat = confusion_matrix(gt_labels, pred_labels)
print("confusion_mat.shape : {}".format(confusion_mat.shape))
print("confusion_mat : {}".format(confusion_mat))

5、混淆矩阵可视化,在scikit-learn工具中有一个plot_confusion_matrix方法可以可视化sklearn训练的模型estimator的混淆矩阵,具体参数如下:

        但是,现在的问题是我们使用的是别的框架训练的模型,也就没有这个estimator参数可以供sklearn使用,怎么办?

        我们看一下plot_confusion_matrix函数的代码可以发现,他其实内部调用了以下方法:

         那么,我们也仿照这个调用方式来写一下试试,代码如下:

# 使用sklearn工具包中的ConfusionMatrixDisplay可视化混淆矩阵,参考plot_confusion_matrix
disp = ConfusionMatrixDisplay(confusion_matrix=confusion_mat, display_labels=classes)
disp.plot(
    include_values=True,            # 混淆矩阵每个单元格上显示具体数值
    cmap="viridis",                 # 不清楚啥意思,没研究,使用的sklearn中的默认值
    ax=None,                        # 同上
    xticks_rotation="horizontal",   # 同上
    values_format="d"               # 显示的数值格式
)

 6、将以上代码整合一下,输入数据的真实label和预测label,就可以可视化混淆矩阵了,并且不仅局限于评估scikit-learn的estimator,可以适用于所有框架的输出结果,完整代码如下:

import numpy as np
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
from matplotlib import pyplot as plt

classes = ["airplane", "automobile", "bird", "cat", "deer", "dog", "frog", "horse", "ship", "truck"]

# 生成数据集的GT标签
gt_labels = np.zeros(1000).reshape(10, -1)
for i in range(10):
    gt_labels[i] = i
gt_labels = gt_labels.reshape(1, -1).squeeze()
print("gt_labels.shape : {}".format(gt_labels.shape))
print("gt_labels : {}".format(gt_labels[::5]))

# 生成数据集的预测标签
pred_labels = np.zeros(1000).reshape(10, -1)
for i in range(10):
    # 标签生成规则:对于真值类别编号为i的数据,生成的预测类别编号为[0, i-1]之间的随机值
    # 这样生成的预测准确率从0到9逐渐递减
    pred_labels[i] = np.random.randint(0, i + 1, 100)
pred_labels = pred_labels.reshape(1, -1).squeeze()
print("pred_labels.shape : {}".format(pred_labels.shape))
print("pred_labels : {}".format(pred_labels[::5]))

# 使用sklearn工具中confusion_matrix方法计算混淆矩阵
confusion_mat = confusion_matrix(gt_labels, pred_labels)
print("confusion_mat.shape : {}".format(confusion_mat.shape))
print("confusion_mat : {}".format(confusion_mat))

# 使用sklearn工具包中的ConfusionMatrixDisplay可视化混淆矩阵,参考plot_confusion_matrix
disp = ConfusionMatrixDisplay(confusion_matrix=confusion_mat, display_labels=classes)
disp.plot(
    include_values=True,            # 混淆矩阵每个单元格上显示具体数值
    cmap="viridis",                 # 不清楚啥意思,没研究,使用的sklearn中的默认值
    ax=None,                        # 同上
    xticks_rotation="horizontal",   # 同上
    values_format="d"               # 显示的数值格式
)
plt.show()

7、混淆矩阵的可视化结果

        上图中的可视化结果符合我们在生成预测label标签时使用的规则,就是对于每个类别 i 的预测结果是0-i之间的随机值,这样的话,每个类别的预测误差只会出现在类别编号比它小的部分,也就是上图中展示的下三角矩阵。

        在混淆矩阵中,横轴上的标签标示样本的预测label,纵轴上的标签标示样本的实际label。所以,对角线上的数字表示预测label和真是label一致的数量,也就是预测正确的数量。对于其他位置的数字就表示预测错误的,举个例子,比如第2行、第1列,也就是对应着(airplane, automobile)位置的数字51,表示有51个真实label为automobile的样本被预测为了airplane。

        通过可视化的混淆矩阵,模型的误差,以及效果分类不好的类别,以及为什么不好,以及容易和哪个类之间出现误识别就一目了然了。

参考:https://blog.csdn.net/cxx654/article/details/107296343

  • 14
    点赞
  • 68
    收藏
    觉得还不错? 一键收藏
  • 4
    评论
您可以使用R语言中的`caret`包来计算二分类逻辑回归模型的混淆矩阵,并使用`ggplot2`包来进行可视化。以下是一个示例代码: ```R # 导入所需的包 library(caret) library(ggplot2) # 创建一个示例数据集 data <- data.frame(X1 = rnorm(100), X2 = rnorm(100), Y = factor(sample(c(0, 1), 100, replace = TRUE))) # 拆分数据集为训练集和测试集 set.seed(123) trainIndex <- createDataPartition(data$Y, p = 0.7, list = FALSE) trainData <- data[trainIndex, ] testData <- data[-trainIndex, ] # 训练逻辑回归模型 model <- train(Y ~ ., data = trainData, method = "glm", family = "binomial") # 在测试集上进行预测 predictions <- predict(model, newdata = testData) # 创建混淆矩阵 confusionMatrix <- caret::confusionMatrix(predictions, testData$Y) # 可视化混淆矩阵 ggplot(confusionMatrix$table, aes(Prediction, Reference, fill = Freq)) + geom_tile() + geom_text(aes(label = Freq), color = "white") + theme_minimal() + scale_fill_gradient(low = "white", high = "steelblue") + labs(title = "Confusion Matrix", x = "Prediction", y = "Reference") ``` 这段代码首先导入了`caret`和`ggplot2`包,然后创建了一个示例数据集。接下来,使用`createDataPartition`函数将数据集拆分为训练集和测试集。然后,使用`train`函数训练逻辑回归模型,并使用`predict`函数在测试集上进行预测。随后,使用`confusionMatrix`函数计算混淆矩阵。最后,使用`ggplot`函数将混淆矩阵可视化为热力图。 请注意,这只是一个示例代码,您需要根据自己的数据和需求进行适当的修改。
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值