在实际做项目的时候,我们经常需要找出一些错例、难例,并对它们进行分析,从而让模型达到更好的效果。下面介绍在目标检测任务和分类任务中进行badcase可视化的方法。
利用目标检测混淆矩阵写badcase可视化(Ultralytics版)
def process_batch(self, detections, labels):
    """
    Update confusion matrix for object detection task.

    ``self.matrix`` is indexed as ``[predicted_class, ground_truth_class]``; the extra
    index ``self.nc`` is used for "background" on either axis. Detections whose
    confidence is not above ``self.conf`` are ignored, and a detection is matched to a
    ground-truth box only when their IoU exceeds ``self.iou_thres``.

    Args:
        detections (Array[N, 6] | None): Detected bounding boxes and their associated information.
            Each row should contain (x1, y1, x2, y2, conf, class).
        labels (Array[M, 5]): Ground truth bounding boxes and their associated class labels.
            Each row should contain (class, x1, y1, x2, y2).
    """
    if labels.size(0) == 0:  # Check if labels is empty
        if detections is not None:
            detections = detections[detections[:, 4] > self.conf]
            for dc in detections[:, 5].int().tolist():
                self.matrix[dc, self.nc] += 1  # false positives
        return
    if detections is None:
        # Only column 0 holds the class id; iterating over whole rows would use the
        # box coordinates as matrix indices.
        for gc in labels[:, 0].int().tolist():
            self.matrix[self.nc, gc] += 1  # background FN
        return

    detections = detections[detections[:, 4] > self.conf]
    # Plain Python ints make the matrix indexing below independent of tensor device/dtype.
    gt_classes = labels[:, 0].int().tolist()
    detection_classes = detections[:, 5].int().tolist()
    iou = box_iou(labels[:, 1:], detections[:, :4])

    x = torch.where(iou > self.iou_thres)
    if x[0].shape[0]:
        # Each row of `matches` is (label index, detection index, IoU).
        matches = torch.cat((torch.stack(x, 1), iou[x[0], x[1]][:, None]), 1).cpu().numpy()
        if x[0].shape[0] > 1:
            # Sort by IoU (descending) and keep one row per detection, then one row per label,
            # so the result is a one-to-one assignment.
            matches = matches[matches[:, 2].argsort()[::-1]]
            matches = matches[np.unique(matches[:, 1], return_index=True)[1]]
            matches = matches[matches[:, 2].argsort()[::-1]]
            matches = matches[np.unique(matches[:, 0], return_index=True)[1]]
    else:
        matches = np.zeros((0, 3))

    n = matches.shape[0] > 0
    m0, m1, _ = matches.transpose().astype(int)
    for i, gc in enumerate(gt_classes):
        j = m0 == i
        if n and j.sum() == 1:
            self.matrix[detection_classes[int(m1[j][0])], gc] += 1  # correct
        else:
            self.matrix[self.nc, gc] += 1  # true background

    # NOTE(review): unmatched detections are only counted when at least one match exists
    # (`n` is truthy); this mirrors the original code - confirm it is intended.
    if n:
        matched_detections = set(m1.tolist())
        for i, dc in enumerate(detection_classes):
            if i not in matched_detections:
                self.matrix[dc, self.nc] += 1  # predicted background
val的过程中调用了ConfusionMatrix的process_batch方法,用于更新混淆矩阵。
# Invoked during validation to update the confusion matrix with the predictions and labels.
self.confusion_matrix.process_batch(predn, labelsn)
我对混淆矩阵代码进行了修改,在函数最后返回了fp,fn的索引,修改的代码片段如下:
def process_batch(self, detections, gt_bboxes, gt_cls):
    """
    Update confusion matrix for object detection task and report the bad cases.

    ``self.matrix`` is indexed as ``[predicted_class, ground_truth_class]``; the extra
    index ``self.nc`` is used for "background" on either axis. Detections whose
    confidence is not above ``self.conf`` are ignored.

    Args:
        detections (Array[N, 6] | Array[N, 7] | None): Detected bounding boxes and their associated information.
            Each row should contain (x1, y1, x2, y2, conf, class)
            or with an additional element `angle` when it's obb.
        gt_bboxes (Array[M, 4]| Array[N, 5]): Ground truth bounding boxes with xyxy/xyxyr format.
        gt_cls (Array[M]): The class labels.

    Returns:
        (tuple[list[int], list[int]]): ``(fp, fn)``.
            ``fp`` holds row indices into the *input* ``detections`` (before confidence
            filtering) of detections that matched no ground truth or matched one of a
            different class. ``fn`` holds indices into ``gt_cls`` of ground-truth objects
            that were not matched by any detection.
    """
    fp = []
    fn = []
    if gt_cls.shape[0] == 0:  # Check if labels is empty
        if detections is not None:
            # Every detection that passes the confidence threshold is a false positive.
            kept = torch.nonzero(detections[:, 4] > self.conf).flatten().tolist()
            for det_idx in kept:
                self.matrix[int(detections[det_idx, 5]), self.nc] += 1  # false positives
                fp.append(det_idx)
        return fp, fn
    if detections is None:
        # Without detections every ground-truth object is missed.
        for gt_idx, gc in enumerate(gt_cls.int().reshape(-1).tolist()):
            self.matrix[self.nc, gc] += 1  # background FN
            fn.append(gt_idx)
        return fp, fn

    # Remember which input rows survive the confidence filter so that the indices
    # reported in `fp` refer to the caller's unfiltered detections.
    kept = torch.nonzero(detections[:, 4] > self.conf).flatten()
    detections = detections[kept]
    kept = kept.tolist()

    gt_classes = gt_cls.int().reshape(-1).tolist()
    detection_classes = detections[:, 5].int().tolist()
    is_obb = detections.shape[1] == 7 and gt_bboxes.shape[1] == 5  # with additional `angle` dimension
    iou = (
        batch_probiou(gt_bboxes, torch.cat([detections[:, :4], detections[:, -1:]], dim=-1))
        if is_obb
        else box_iou(gt_bboxes, detections[:, :4])
    )

    x = torch.where(iou > self.iou_thres)
    if x[0].shape[0]:
        # Each row of `matches` is (ground-truth index, detection index, IoU).
        matches = torch.cat((torch.stack(x, 1), iou[x[0], x[1]][:, None]), 1).cpu().numpy()
        if x[0].shape[0] > 1:
            # Sort by IoU (descending) and keep one row per detection, then one row per
            # ground truth, so the result is a one-to-one assignment.
            matches = matches[matches[:, 2].argsort()[::-1]]
            matches = matches[np.unique(matches[:, 1], return_index=True)[1]]
            matches = matches[matches[:, 2].argsort()[::-1]]
            matches = matches[np.unique(matches[:, 0], return_index=True)[1]]
    else:
        matches = np.zeros((0, 3))

    n = matches.shape[0] > 0
    m0, m1, _ = matches.transpose().astype(int)
    for i, gc in enumerate(gt_classes):
        j = m0 == i
        if n and j.sum() == 1:
            det_idx = int(m1[j][0])
            dc = detection_classes[det_idx]
            self.matrix[dc, gc] += 1  # correct
            if dc != gc:
                fp.append(kept[det_idx])  # matched a box but predicted the wrong class
        else:
            self.matrix[self.nc, gc] += 1  # true background
            fn.append(i)

    # NOTE(review): unmatched detections are only counted/reported when at least one
    # match exists (`n` is truthy); this mirrors the original code - confirm it is intended.
    if n:
        matched_detections = set(m1.tolist())
        for i, dc in enumerate(detection_classes):
            if i not in matched_detections:
                self.matrix[dc, self.nc] += 1  # predicted background
                fp.append(kept[i])
    return fp, fn
在调用此方法时获取fp、fn的索引 :
# The modified process_batch returns the false-positive and false-negative entries.
fp, fn = self.confusion_matrix.process_batch(predn, bbox, cls)
然后按照自己的喜好去写badcase的可视化,以下是我修改的一个示例:
fp, fn = self.confusion_matrix.process_batch(predn, bbox, cls)
# gyh vis_badcase start
# For every validated image with at least one false positive or false negative, save a
# side-by-side picture: ground truth on one canvas, predictions on the other, with the
# offending boxes highlighted. `vis_badcase` / `vis_badcase_path` are custom arguments
# supplied when calling model.val().
# Prefer the configured task name; fall back to parsing the validator's repr.
task = getattr(self.args, 'task', None) or f'{self}'.split('.')[3]
mode = self.args.mode
if task == 'detect' and mode == 'val' and self.args.vis_badcase and (fp or fn):
    from torchvision.transforms import ToPILImage
    import numpy as np

    # (0, 0, 255) is red in OpenCV's BGR order. The canvas comes from PIL (RGB), so this
    # assumes merge_images writes the arrays with OpenCV - TODO confirm.
    bad_color, ok_color = (0, 0, 255), (0, 255, 0)

    def _flagged(entries, everything):
        # process_batch may report the "no ground truth" / "no detections" cases with string
        # placeholders ('fp' / 'fn') instead of indices; then every candidate is flagged.
        if any(isinstance(entry, str) for entry in entries):
            return set(everything)
        # Entries may be plain ints or 1-element numpy arrays.
        return {int(np.asarray(entry).reshape(-1)[0]) for entry in entries}

    def _draw_box(canvas, corners, text, color):
        # Draw one xyxy box (pixel coordinates) and its caption just above the top-left corner.
        left, top, right, bottom = (int(value) for value in corners)
        cv2.rectangle(canvas, (left, top), (right, bottom), color, 2)
        cv2.putText(canvas, text, (left, top - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, color, 2)

    # Assumes a validation batch size of 1 (the leading batch dimension is squeezed away).
    img_tensor = batch['img'].cpu().squeeze(0)
    ori_img = ToPILImage()(img_tensor)
    gt_img = np.array(ori_img)  # np.array copies, so the canvas is writable
    pred_img = gt_img.copy()
    img_h, img_w = gt_img.shape[:2]

    # gt: ground-truth boxes are normalized (cx, cy, w, h); scale them to the canvas size.
    gt_cls = batch['cls'].cpu().numpy().reshape(-1)  # flat even for 0 or 1 objects
    gt_bboxes = batch['bboxes'].cpu().numpy().reshape(-1, 4)
    missed = _flagged(fn, range(len(gt_cls)))
    for i, (cls_id, (cx, cy, bw, bh)) in enumerate(zip(gt_cls, gt_bboxes)):
        cx, bw = cx * img_w, bw * img_w
        cy, bh = cy * img_h, bh * img_h
        corners = (cx - bw / 2, cy - bh / 2, cx + bw / 2, cy + bh / 2)
        # Missed objects (false negatives) are highlighted, the rest stay green.
        _draw_box(gt_img, corners, self.names[int(cls_id)], bad_color if i in missed else ok_color)

    # pred: rows of preds[0] are (x1, y1, x2, y2, conf, class) and are drawn as-is.
    # NOTE(review): `fp` entries are interpreted as row indices into preds[0]; this assumes
    # process_batch reports indices relative to the unfiltered detections - verify.
    det_rows = preds[0].detach().cpu().numpy()
    conf_thres = getattr(self.confusion_matrix, 'conf', 0.0)
    above_conf = [k for k, row in enumerate(det_rows) if row[4] > conf_thres]
    wrong = _flagged(fp, above_conf)
    for i, row in enumerate(det_rows):
        label = self.names[int(row[5])] + "{:.2f}".format(float(row[4]))
        # False detections are highlighted, the rest stay green.
        _draw_box(pred_img, row[:4], label, bad_color if i in wrong else ok_color)

    # Accept both Windows and POSIX separators in the source image path.
    img_name = os.path.basename(batch['im_file'][0].replace('\\', '/'))
    os.makedirs(self.args.vis_badcase_path, exist_ok=True)
    output = os.path.join(self.args.vis_badcase_path, img_name)
    merge_images(gt_img, pred_img, output)  # simple helper that stitches the two canvases for comparison
# gyh vis_badcase end
代码看起来可能有点复杂,但在处理目标很多的图像时(例如培养皿中的细胞颗粒),这个脚本非常好用!由于保密问题,我不能展示细胞图片,不过可以展示这个脚本呈现的效果:
左边是真值,右边是预测值,左图的红色框框就是漏检的object。我们再看一张图:
右边的红色框框就是错检的object。不过这里其实预测没有错,而是标注错了,所以这个脚本还能用来检查标注中的错例。
分类任务的badcase可视化(Ultralytics版)
def process_cls_preds(self, preds, targets):
    """
    Accumulate classification results into the confusion matrix.

    The first column of the concatenated predictions is taken as the predicted class;
    ``self.matrix[predicted][target]`` is incremented once per sample.

    Args:
        preds (Array[N, min(nc,5)]): Predicted class labels.
        targets (Array[N, 1]): Ground truth class labels.
    """
    predicted = torch.cat(preds)[:, 0].cpu().numpy()
    expected = torch.cat(targets).cpu().numpy()
    for pred_cls, true_cls in zip(predicted, expected):
        row = self.matrix[pred_cls]
        row[true_cls] += 1
分类任务就简单多了,甚至我好像没有借助到混淆矩阵这个方法,我直接在后处理之后,将真值与预测值的类别进行对比,如果类别不一致就将这个错例保存下来,具体代码如下:
# gyh 0709 plot badcase
# Save every misclassified image under <vis_cls_badcase_path>/<true label>/ with the text
# "<true label>, <predicted label>" drawn on it. `vis_cls_badcase` / `vis_cls_badcase_path`
# are custom arguments.
# Prefer the configured task name; fall back to parsing the validator's repr.
task = getattr(self.args, 'task', None) or f'{self}'.split('.')[3]
if task == 'classify' and self.args.vis_cls_badcase:
    import os
    import cv2
    from torchvision.transforms import ToPILImage

    # One sub-folder per class name.
    for class_name in self.names.values():
        os.makedirs(os.path.join(self.args.vis_cls_badcase_path, class_name), exist_ok=True)

    _, pred_cls = preds.max(dim=1)
    # Compare plain ints per sample so any batch size works.
    true_ids = batch['cls'].reshape(-1).tolist()
    pred_ids = pred_cls.reshape(-1).tolist()
    to_pil = ToPILImage()
    for image_tensor, true_id, pred_id in zip(batch['img'].cpu(), true_ids, pred_ids):
        if int(true_id) == int(pred_id):
            continue
        # get ori_img: np.array copies, giving cv2.putText a writable buffer
        image_np = np.array(to_pil(image_tensor))
        if image_np.ndim == 3 and image_np.shape[2] == 3:
            # cv2.imwrite expects BGR; assumes batch['img'] is RGB - TODO confirm.
            image_np = cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR)
        # get info
        label = self.names[int(true_id)]
        pred = self.names[int(pred_id)]
        text = f"{label}, {pred}"
        # save
        # NOTE(review): `idx` is a running counter that must be defined by the surrounding
        # code before this snippet runs; its definition is not shown here.
        save_name = os.path.join(self.args.vis_cls_badcase_path, label, f'{idx}.jpg')
        idx += 1
        # plot
        cv2.putText(image_np, text, (5, 10), cv2.FONT_HERSHEY_SIMPLEX, 0.3, (255, 255, 255), 1)
        cv2.imwrite(save_name, image_np)
# gyh 0709 plot badcase
这个就不展示效果啦,如果目标检测任务的badcase能弄明白,分类任务就很简单啦。
最近在做一些实例分割的项目,mask的badcase还在研究中...