混淆矩阵Confusion Matrix在深度学习中的使用及其实现

        Confusion Matrix是一种可视化的工具,x轴代表的是预测种类,y轴代表的是真实种类,对应的二维坐标点是当某件物品的真实种类是y时,被认成x的个数,对角线上的数字表明对种类做出了正确的判断,其它的地方表明我们的model混淆了种类。通过混淆矩阵我们可以清楚的了解在对于某个类别,它预测错误的偏向,这有助于给我们提供改进模型的方向
        以下代码使用sklearn的confusion_matrix内置函数计算混淆矩阵,也可以自定义函数取实现,使用matplotlib可视化。

# coding:utf-8
from sklearn.utils.multiclass import unique_labels
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
import numpy as np

def plot_confusion_matrix(y_true, y_pred, classes,
                          normalize=False,
                          title=None,
                          cmap=plt.cm.Blues):
    """
    This function prints and plots the confusion matrix.
    Normalization can be applied by setting `normalize=True`.
    """
    # 如果没有传入title参数,则根据以下逻辑生成title
    if not title:
        if normalize:
            title = 'Normalized confusion matrix'
        else:
            title = 'Confusion matrix, without normalization'

    # 通过klearn.metrics包中的函数计算混淆矩阵
    cm = confusion_matrix(y_true, y_pred)

    '''
    unique_labels(列表1,列表2...) 将列表们的值去交集、去重,按最后结果按从左到右的出现顺序排列
    unique_labels([1, 2, 10], [5, 11])
    array([ 1,  2,  5, 10, 11])
    '''

    # 更新列表,仅使用数据中显示的标签
    classes = classes[unique_labels(y_true, y_pred)]
    if normalize:
        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
        print("Normalized confusion matrix")
    else:
        pass
        print('Confusion matrix, without normalization')

    print('confusion_matrix:\n')
    print(cm)
    '''
    Axes.imshow(self, X, cmap=None, norm=None, aspect=None, interpolation=None, alpha=None, vmin=None, vmax=None, origin=None, extent=None, shape=, filternorm=1, filterrad=4.0, imlim=, resample=None, url=None, *, data=None, **kwargs)
    参数X表示图像的数据
    渐变色 cmap 取值参照https://matplotlib.org/stable/tutorials/colors/colormaps.html
    透明度 alpha 0-1
    aspect用于指定热图的单元格的大小
    interpolation 控制热图的颜色显示形式,是否平滑 常用nearest/lanczos
    '''
    fig, ax = plt.subplots()

    im = ax.imshow(X=cm, interpolation='lanczos', cmap=cmap)
    ax.figure.colorbar(im, ax=ax)
    # We want to show all ticks...
    ax.set(xticks=np.arange(cm.shape[1]),
           yticks=np.arange(cm.shape[0]),
           # ... and label them with the respective list entries
           xticklabels=classes, yticklabels=classes,
           title=title,
           ylabel='True label',
           xlabel='Predicted label')

    ax.set_ylim(len(classes)-0.5, -0.5)

    # 旋转刻度标签并设置其对齐方式
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right",
             rotation_mode="anchor")

    # 循环数据标注并创建文本注释
    fmt = '.2f' if normalize else 'd'
    thresh = cm.max() / 2.
    for i in range(cm.shape[0]):
        for j in range(cm.shape[1]):
            ax.text(j, i, format(cm[i, j], fmt),
                    ha="center", va="center",
                    color="white" if cm[i, j] > thresh else "black")
    fig.tight_layout()
    plt.show()
    return ax

# y_test为真实label,y_pred为预测label,classes为类别名称,是个ndarray数组,内容为string类型的标签
y_test = [1, 0, 2, 1, 0, 1, 1]
y_pred = [0, 0, 2, 1, 0, 1, 2]
class_names = np.array(["cat", "dog", "pig"]) #按你的实际需要修改名称
plot_confusion_matrix(y_test, y_pred, classes=class_names, normalize=False)

在这里插入图片描述

  • 2
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

TuringQi

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值