在使用 mmsegmentation 框架时,大多数情况下确实非常方便,但有时也会因为版本不兼容,导致一些在特定版本中可用的代码在另一个版本中频繁报错。例如,新旧版本中 mmseg.apis 包下用于模型初始化、推理和可视化的 inference_model、init_model、show_result_pyplot 这几个接口就经常出错。
在新版本中,官方给出的绘制混淆矩阵的代码无法直接使用,这一点已经在 mmsegmentation 的 issue 中得到确认。笔者对该代码进行了改进,使其能够美观地绘制混淆矩阵。代码如下:
import os
import numpy as np
import cv2
from mmseg.apis import init_model, inference_model
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
import itertools
# Raw label value -> contiguous class index (six classes, 0-5).
# The dataset's own label values are not 0-5, hence this remapping. If your
# labels already match the class indices, map each value to itself
# (0: 0, 1: 1, ...).
# Kept at module level so the table is built once rather than on every call.
_LABEL_MAPPING = {
    0: 0,
    14: 4,
    38: 1,
    52: 5,
    113: 3,
    75: 2,
    # Add further raw-value -> class-index pairs here if needed.
}


def map_label(label_value):
    """Return the class index for a raw label value.

    Args:
        label_value: Raw value read from a label image.

    Returns:
        The mapped class index, or -1 when ``label_value`` has no entry in
        the mapping table.
    """
    return _LABEL_MAPPING.get(label_value, -1)
def compute_confusion_matrix(model, img_dir, label_dir, num_classes=6,
                             unknown_class=5):
    """Run the model on every image in a directory and build a confusion matrix.

    For each file in ``img_dir`` the label image with the same base name and
    a ``.png`` extension is read from ``label_dir`` in grayscale, remapped to
    class indices with ``map_label``, and compared pixel by pixel with the
    model's predicted mask.

    Args:
        model: Segmentation model passed to ``inference_model``.
        img_dir: Directory containing the input images.
        label_dir: Directory containing the matching ``.png`` label images.
        num_classes: Number of classes; the returned matrix always has shape
            ``(num_classes, num_classes)``, even if a class never occurs.
        unknown_class: Class index assigned to label values that
            ``map_label`` does not recognise (it returns -1 for those).

    Returns:
        The confusion matrix from ``sklearn.metrics.confusion_matrix`` with
        true classes on the rows and predicted classes on the columns.

    Raises:
        FileNotFoundError: If an image or its label cannot be read.
        ValueError: If ``img_dir`` is empty or a prediction and its label
            have different shapes.
    """
    # Grayscale labels are read as 8-bit, so a 256-entry lookup table covers
    # every possible pixel value. Building it once replaces a Python-level
    # map_label call per pixel with a single array indexing operation.
    lut = np.array([map_label(value) for value in range(256)])
    lut[lut == -1] = unknown_class

    all_true_labels = []
    all_pred_masks = []
    for img_name in sorted(os.listdir(img_dir)):
        img_path = os.path.join(img_dir, img_name)
        # Swap only the real extension; str.replace would also hit '.jpg'
        # occurring in the middle of a name and ignore other extensions.
        label_name = os.path.splitext(img_name)[0] + '.png'
        label_path = os.path.join(label_dir, label_name)

        # cv2.imread signals failure by returning None instead of raising.
        img_bgr = cv2.imread(img_path)
        if img_bgr is None:
            raise FileNotFoundError(f'Cannot read image: {img_path}')
        label = cv2.imread(label_path, cv2.IMREAD_GRAYSCALE)
        if label is None:
            raise FileNotFoundError(f'Cannot read label: {label_path}')

        result = inference_model(model, img_bgr)
        pred_mask = result.pred_sem_seg.data[0].detach().cpu().numpy()
        if pred_mask.shape != label.shape:
            raise ValueError(
                f'Prediction shape {pred_mask.shape} does not match label '
                f'shape {label.shape} for {img_path}')

        all_true_labels.append(lut[label].ravel())
        all_pred_masks.append(pred_mask.ravel())

    if not all_true_labels:
        raise ValueError(f'No images found in {img_dir}')

    # Passing labels explicitly keeps the matrix num_classes x num_classes
    # even when some class is absent from both the labels and predictions.
    return confusion_matrix(
        np.concatenate(all_true_labels),
        np.concatenate(all_pred_masks),
        labels=np.arange(num_classes),
    )
def cnf_matrix_plotter(cm, classes, cmap=plt.cm.Blues):
    """Plot a row-normalised confusion matrix and save it to disk.

    Each row of ``cm`` is divided by its sum, so every cell shows the
    fraction (0-1) of that row's samples falling into each column.

    Args:
        cm: Square confusion matrix of raw counts (rows are labelled
            'True', columns 'Pred').
        classes: Tick labels, one per class, used on both axes.
        cmap: Matplotlib colormap (object or name) for the heat map.

    Returns:
        None. The figure is written via ``plt.savefig('confusion_matrix')``;
        matplotlib chooses the file extension from its default save format.
    """
    cm = np.asarray(cm, dtype=float)
    row_sums = cm.sum(axis=1, keepdims=True)
    # A class with no samples has a zero row sum; leave that row at 0
    # instead of dividing by zero and filling the plot with NaN.
    cm_ratio = np.divide(cm, row_sums, out=np.zeros_like(cm),
                         where=row_sums != 0)

    plt.figure(figsize=(10, 10))
    plt.imshow(cm_ratio, interpolation='nearest', cmap=cmap)
    tick_marks = np.arange(len(classes))
    plt.title('Confusion Matrix', fontsize=30)
    plt.xlabel('Pred', fontsize=25, c='r')
    plt.ylabel('True', fontsize=25, c='r')
    plt.tick_params(labelsize=16)  # font size of the class tick labels
    plt.xticks(tick_marks, classes, rotation=90)  # rotate x-axis labels
    plt.yticks(tick_marks, classes)

    # Write the value into each cell. The threshold must be on the same
    # scale as the plotted ratios so text on dark cells is drawn in white.
    threshold = cm_ratio.max() / 2
    n_rows, n_cols = cm_ratio.shape
    for i, j in itertools.product(range(n_rows), range(n_cols)):
        plt.text(j, i, "{:.2f}".format(cm_ratio[i, j]),
                 horizontalalignment="center",
                 color="white" if cm_ratio[i, j] > threshold else "black",
                 fontsize=12)

    plt.tight_layout()
    plt.savefig('confusion_matrix')
# Model config file and trained checkpoint (weights) to evaluate.
config_file = r"./my_moudel/deeplabv3plus/deeplabv3plus_r50-d8_4xb4-40k_voc12aug-512x512.py"
checkpointfile = r'./work_dirs/rccnet_cbam_cat/best_mIoU_iter_9550.pth'

# Input images and their label images.
img_dir = r"./img_predict/imgs"
label_dir = r"./img_predict/labels"

# Tick labels for both axes of the plot: '0' through '5'.
classes = [str(class_id) for class_id in range(6)]

# Load the model, accumulate the confusion matrix, then plot and save it.
model = init_model(config_file, checkpointfile, device='cuda:0')
confusion_matrix_model = compute_confusion_matrix(model, img_dir, label_dir)
cnf_matrix_plotter(confusion_matrix_model, classes, cmap='Greens')