MMSegmentation绘制混淆矩阵

在使用mmsegmentation框架时,说方便真的很方便,但是有时候又因为版本不兼容原因,一些在特定版本能使用的代码,在另一个版本中却疯狂报错。例如新老版本对模型初始化的代码包 mmseg.apis 目录下的 inference_model, init_model, show_result_pyplot这几个文件老是出错。

在新的版本中,官方给出的绘制混淆矩阵的代码并不可用,这一点在mmsegmentation中的issue中已经被确认。笔者对代码进行改进,使其能美观地绘制混淆矩阵。代码如下:

import os
import numpy as np
import cv2
from mmseg.apis import init_model, inference_model
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
import itertools

def map_label(label_value):
    # 定义标签映射字典,笔者这里是六个类别,但是数据集本身的标签值并不是0-5,所以我这里做了一个映射,如果大家不匹配也可以做映射;如果本身匹配就可以0映射为0,1映射为1依次类推
    label_mapping = {
        0: 0,
        14: 4,
        38: 1,
        52: 5,
        113: 3,
        75: 2
        # 如果还有其他标签值,可以继续添加映射关系
    }
    
    # 返回映射后的类别值,如果不在字典中,则返回-1或其他默认值
    return label_mapping.get(label_value, -1)

def compute_confusion_matrix(model, img_dir, label_dir):
    all_true_labels = []
    all_pred_masks = []

    img_names = os.listdir(img_dir)
    img_paths = [os.path.join(img_dir, img_name) for img_name in img_names]
    label_paths = [os.path.join(label_dir, img_name.replace('.jpg', '.png')) for img_name in img_names]

    for img_path, label_path in zip(img_paths, label_paths):
        img_bgr = cv2.imread(img_path)
        result = inference_model(model, img_bgr)
        pred_mask = result.pred_sem_seg.data[0].detach().cpu().numpy()
        label = cv2.imread(label_path, cv2.IMREAD_GRAYSCALE)  # 以灰度形式读取标签图像
        # 将标签值映射到0-5之间的整数
        label_mapped = np.vectorize(map_label)(label)
        label_mapped = np.where(label_mapped == -1, 5, label_mapped)  # 将未知类别映射为5

        all_true_labels.append(label_mapped.flatten())
        all_pred_masks.append(pred_mask.flatten())

    all_true_labels = np.concatenate(all_true_labels)
    all_pred_masks = np.concatenate(all_pred_masks)

    return confusion_matrix(all_true_labels, all_pred_masks)

def cnf_matrix_plotter(cm, classes, cmap=plt.cm.Blues):
    plt.figure(figsize=(10,10))

    cm_percent = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]  # Convert to percentage

    plt.imshow(cm_percent, interpolation='nearest', cmap=cmap)
    tick_marks = np.arange(len(classes))
    plt.title('Confusion Matrix', fontsize=30)
    plt.xlabel('Pred', fontsize=25, c='r')
    plt.ylabel('True', fontsize=25, c='r')
    plt.tick_params(labelsize=16)  # 设置类别文字大小
    plt.xticks(tick_marks, classes, rotation=90)  # 横轴文字旋转
    plt.yticks(tick_marks, classes)
    # 写数宇
    threshold = cm.max()/2
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(j, i, "{:.2f}".format(cm_percent[i, j]),
                 horizontalalignment="center",
                 color="white" if cm_percent[i, j] > threshold else "black",
                 fontsize=12)
    plt.tight_layout()
    plt.savefig('confusion_matrix')

#配置文件路径
config_file = r"./my_moudel/deeplabv3plus/deeplabv3plus_r50-d8_4xb4-40k_voc12aug-512x512.py"
#权重路径
checkpointfile = r'./work_dirs/rccnet_cbam_cat/best_mIoU_iter_9550.pth'
model = init_model(config_file, checkpointfile, device='cuda:0')

img_dir = r"./img_predict/imgs"
label_dir = r"./img_predict/labels"
confusion_matrix_model = compute_confusion_matrix(model, img_dir, label_dir)
#横纵坐标名字
classes = ['0', '1', '2', '3', '4', '5']

cnf_matrix_plotter(confusion_matrix_model, classes, cmap='Greens')

  • 16
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值