Pytorch+Tensorboard混淆矩阵可视化

引言

混淆矩阵是分类任务常用的一种评估方法。对角线元素表示预测标签等于真实标签的点数,而非对角线元素则是分类器未正确标记的点的数量。 混淆矩阵的对角线值越高越好,表明有许多正确的预测。1

尤其是在类别数量不平衡的情况下,相比accuracy,混淆矩阵(confusion matrix)对哪个类被错误分类具有更直观的解释

在平时做简单的数据实验时,可以仅用from sklearn.metrics import plot_confusion_matrix或者seaborn对混淆矩阵进行可视化。但是在深度学习训练模型的过程中,在tensorboard中可视化混淆矩阵会更方便结果记录和对照。

混淆矩阵

在tensorboard中的可视化效果:
在这里插入图片描述

代码实现

代码参考facebook的SlowFast工程2

引用库

import itertools
import numpy as np
import matplotlib.pyplot as plt
import torch
from sklearn.metrics import confusion_matrix

计算混淆矩阵

从pytorch模型输出的预测结果preds、真值labels,计算混淆矩阵。

def get_confusion_matrix(preds, labels, num_classes, normalize="true"):
    """
    Calculate confusion matrix on the provided preds and labels.
    Args:
        preds (tensor or lists of tensors): predictions. Each tensor is in
            in the shape of (n_batch, num_classes). Tensor(s) must be on CPU.
        labels (tensor or lists of tensors): corresponding labels. Each tensor is
            in the shape of either (n_batch,) or (n_batch, num_classes).
        num_classes (int): number of classes. Tensor(s) must be on CPU.
        normalize (Optional[str]) : {‘true’, ‘pred’, ‘all’}, default="true"
            Normalizes confusion matrix over the true (rows), predicted (columns)
            conditions or all the population. If None, confusion matrix
            will not be normalized.
    Returns:
        cmtx (ndarray): confusion matrix of size (num_classes x num_classes)
    """
    if isinstance(preds, list):
        preds = torch.cat(preds, dim=0)
    if isinstance(labels, list):
        labels = torch.cat(labels, dim=0)
    # If labels are one-hot encoded, get their indices.
    if labels.ndim == preds.ndim:
        labels = torch.argmax(labels, dim=-1)
    # Get the predicted class indices for examples.
    preds = torch.flatten(torch.argmax(preds, dim=-1))
    labels = torch.flatten(labels)
    cmtx = confusion_matrix(
        labels, preds, labels=list(range(num_classes)))#, normalize=normalize) 部分版本无该参数
    return cmtx

绘制混淆矩阵

输入get_confusion_matrix获取的混淆矩阵cmtx,类别数量和类别名称,进行混淆矩阵绘制。

def plot_confusion_matrix(cmtx, num_classes, class_names=None, figsize=None):
    """
    A function to create a colored and labeled confusion matrix matplotlib figure
    given true labels and preds.
    Args:
        cmtx (ndarray): confusion matrix.
        num_classes (int): total number of classes.
        class_names (Optional[list of strs]): a list of class names.
        figsize (Optional[float, float]): the figure size of the confusion matrix.
            If None, default to [6.4, 4.8].

    Returns:
        img (figure): matplotlib figure.
    """
    if class_names is None or type(class_names) != list:
        class_names = [str(i) for i in range(num_classes)]

    figure = plt.figure(figsize=figsize)
    plt.imshow(cmtx, interpolation="nearest", cmap=plt.cm.Blues)
    plt.title("Confusion matrix")
    plt.colorbar()
    tick_marks = np.arange(len(class_names))
    plt.xticks(tick_marks, class_names, rotation=45)
    plt.yticks(tick_marks, class_names)

    # Use white text if squares are dark; otherwise black.
    threshold = cmtx.max() / 2.0
    for i, j in itertools.product(range(cmtx.shape[0]), range(cmtx.shape[1])):
        color = "white" if cmtx[i, j] > threshold else "black"
        plt.text(
            j,
            i,
            format(cmtx[i, j], ".2f") if cmtx[i, j] != 0 else ".",
            horizontalalignment="center",
            color=color,
        )

    plt.tight_layout()
    plt.ylabel("True label")
    plt.xlabel("Predicted label")

    return figure

在tensorboard中添加混淆矩阵

将plot_confusion_matrix返回的绘制图像显示在tensorboard中。

from torch.utils.tensorboard import SummaryWriter

def add_confusion_matrix(
    writer,
    cmtx,
    num_classes,
    global_step=None,
    subset_ids=None,
    class_names=None,
    tag="Confusion Matrix",
    figsize=None,
):
    """
    Calculate and plot confusion matrix to a SummaryWriter.
    Args:
        writer (SummaryWriter): the SummaryWriter to write the matrix to.
        cmtx (ndarray): confusion matrix.
        num_classes (int): total number of classes.
        global_step (Optional[int]): current step.
        subset_ids (list of ints): a list of label indices to keep.
        class_names (list of strs, optional): a list of all class names.
        tag (str or list of strs): name(s) of the confusion matrix image.
        figsize (Optional[float, float]): the figure size of the confusion matrix.
            If None, default to [6.4, 4.8].

    """
    if subset_ids is None or len(subset_ids) != 0:
        # If class names are not provided, use class indices as class names.
        if class_names is None:
            class_names = [str(i) for i in range(num_classes)]
        # If subset is not provided, take every classes.
        if subset_ids is None:
            subset_ids = list(range(num_classes))

        sub_cmtx = cmtx[subset_ids, :][:, subset_ids]
        sub_names = [class_names[j] for j in subset_ids]

        sub_cmtx = plot_confusion_matrix(
            sub_cmtx,
            num_classes=len(subset_ids),
            class_names=sub_names,
            figsize=figsize,
        )
        # Add the confusion matrix image to writer.
        writer.add_figure(tag=tag, figure=sub_cmtx, global_step=global_step)

在训练过程中绘制混淆矩阵

    model.train()
    
    # 预测值和标注值,用于绘制混淆矩阵
    preds=[]
    labels=[]
    
    for i, (inputs, targets) in enumerate(data_loader):
        targets = targets.to(device, non_blocking=True)#shape: (n_batch,)

        try:
            outputs = model(inputs)#shape: (n_batch,n_classes)
            loss = criterion(outputs, targets)            
            
            # 需将tensor从gpu转到cpu上
            preds.append(outputs.cpu())
            labels.append(targets.cpu())

            acc,recall = calculate_precision_and_recall(outputs, targets,pos_label=0)
    
            losses.update(float(loss.item()), inputs.size(0))
            accuracies.update(float(acc), inputs.size(0))
            recalls.update(float(recall), inputs.size(0))
            
            #total_loss+=float(loss.item())
    
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
    
        """
        混淆矩阵可视化
        """
        preds = torch.cat(preds,dim=0)
        labels = torch.cat(labels,dim=0)
        cmtx = get_confusion_matrix(preds,labels,len(class_names))
        add_confusion_matrix(tb_writer,cmtx,num_classes=len(class_names),class_names=class_names,tag="Train Confusion Matrix",figsize=[10,8])
        
    

  1. https://scikit-learn.org/stable/auto_examples/model_selection/plot_confusion_matrix.html ↩︎

  2. https://github.com/facebookresearch/SlowFast/tree/master/slowfast/visualization ↩︎

  • 9
    点赞
  • 54
    收藏
    觉得还不错? 一键收藏
  • 7
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 7
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值