利用混淆矩阵写badcase可视化

在实际做项目的时候,我们想找出一些错例、难例,对这些错例难例进行分析,使模型达到更好的效果。下面是在目标检测任务和分类任务进行badcase可视化的方法。

利用目标检测混淆矩阵写badcase可视化(Ultralytics版)

def process_batch(self, detections, labels):
    """
    Update confusion matrix for object detection task.

    Args:
        detections (Array[N, 6]): Detected bounding boxes and their associated information.
                                    Each row should contain (x1, y1, x2, y2, conf, class).
        labels (Array[M, 5]): Ground truth bounding boxes and their associated class labels.
                                Each row should contain (class, x1, y1, x2, y2).
    """
    if labels.size(0) == 0:  # Check if labels is empty
        if detections is not None:
            detections = detections[detections[:, 4] > self.conf]
            detection_classes = detections[:, 5].int()
            for dc in detection_classes:
                self.matrix[dc, self.nc] += 1  # false positives
        return
    if detections is None:
        gt_classes = labels.int()
        for gc in gt_classes:
            self.matrix[self.nc, gc] += 1  # background FN
        return

    detections = detections[detections[:, 4] > self.conf]
    gt_classes = labels[:, 0].int()
    detection_classes = detections[:, 5].int()
    iou = box_iou(labels[:, 1:], detections[:, :4])

    x = torch.where(iou > self.iou_thres)
    if x[0].shape[0]:
        matches = torch.cat((torch.stack(x, 1), iou[x[0], x[1]][:, None]), 1).cpu().numpy()
        if x[0].shape[0] > 1:
            matches = matches[matches[:, 2].argsort()[::-1]]
            matches = matches[np.unique(matches[:, 1], return_index=True)[1]]
            matches = matches[matches[:, 2].argsort()[::-1]]
            matches = matches[np.unique(matches[:, 0], return_index=True)[1]]
    else:
        matches = np.zeros((0, 3))

    n = matches.shape[0] > 0
    m0, m1, _ = matches.transpose().astype(int)
    for i, gc in enumerate(gt_classes):
        j = m0 == i
        if n and sum(j) == 1:
            self.matrix[detection_classes[m1[j]], gc] += 1  # correct
        else:
            self.matrix[self.nc, gc] += 1  # true background

    if n:
        for i, dc in enumerate(detection_classes):
            if not any(m1 == i):
                self.matrix[dc, self.nc] += 1  # predicted background

val的过程中调用了ConfusionMatrix的process_batch方法,用于更新混淆矩阵。

self.confusion_matrix.process_batch(predn, labelsn)

我对混淆矩阵代码进行了修改,在函数最后返回了fp,fn的索引,修改的代码片段如下:

def process_batch(self, detections, gt_bboxes, gt_cls):
    """
    Update confusion matrix for object detection task.

    Args:
        detections (Array[N, 6] | Array[N, 7]): Detected bounding boxes and their associated information.
                                    Each row should contain (x1, y1, x2, y2, conf, class)
                                    or with an additional element `angle` when it's obb.
        gt_bboxes (Array[M, 4]| Array[N, 5]): Ground truth bounding boxes with xyxy/xyxyr format.
        gt_cls (Array[M]): The class labels.
    """
    fp = []
    fn = []

    if gt_cls.shape[0] == 0:  # Check if labels is empty
        if detections is not None:
            detections = detections[detections[:, 4] > self.conf]
            detection_classes = detections[:, 5].int()
            for dc in detection_classes:
                self.matrix[dc, self.nc] += 1  # false positives
                fp.append('fp')
        return fp, fn
    if detections is None:
        gt_classes = gt_cls.int()
        for gc in gt_classes:
            self.matrix[self.nc, gc] += 1  # background FN
            fn.append('fn')
        return fp, fn

    detections = detections[detections[:, 4] > self.conf]
    gt_classes = gt_cls.int()
    detection_classes = detections[:, 5].int()
    is_obb = detections.shape[1] == 7 and gt_bboxes.shape[1] == 5  # with additional `angle` dimension
    iou = (
        batch_probiou(gt_bboxes, torch.cat([detections[:, :4], detections[:, -1:]], dim=-1))
        if is_obb
        else box_iou(gt_bboxes, detections[:, :4])
    )

    x = torch.where(iou > self.iou_thres)
    if x[0].shape[0]:
        matches = torch.cat((torch.stack(x, 1), iou[x[0], x[1]][:, None]), 1).cpu().numpy()
        if x[0].shape[0] > 1:
            matches = matches[matches[:, 2].argsort()[::-1]]
            matches = matches[np.unique(matches[:, 1], return_index=True)[1]]
            matches = matches[matches[:, 2].argsort()[::-1]]
            matches = matches[np.unique(matches[:, 0], return_index=True)[1]]
    else:
        matches = np.zeros((0, 3))

    n = matches.shape[0] > 0
    m0, m1, _ = matches.transpose().astype(int)
    for i, gc in enumerate(gt_classes):
        j = m0 == i
        if n and sum(j) == 1:
            self.matrix[detection_classes[m1[j]], gc] += 1  # correct
            if detection_classes[m1[j]] != gc:
                fp.append(m1[j])
        else:
            self.matrix[self.nc, gc] += 1  # true background
            fn.append(i)

    if n:
        for i, dc in enumerate(detection_classes):
            if not any(m1 == i):
                self.matrix[dc, self.nc] += 1  # predicted background
                fp.append(i)

    return fp, fn

在调用此方法时获取fp、fn的索引 :

fp, fn = self.confusion_matrix.process_batch(predn, bbox, cls)

然后按照自己的喜好去写badcase的可视化,以下是我修改的一个示例:

fp, fn = self.confusion_matrix.process_batch(predn, bbox, cls)
# gyh vis_badcase start
task = f'{self}'.split('.')[3] # 这里在判断当前的任务是否是detection
mode = self.args.mode
if task == 'detect' and mode == 'val' and self.args.vis_badcase: # 这里我传入了一个标志位,如果在调用model.val方法时将vis_badcase设为True,才会执行下面的内容
    if fp or fn:
        # gt 这是对真值进行处理
        from torchvision.transforms import ToPILImage
        import numpy as np
        img_tensor = batch['img'].cpu().squeeze(0)
        gt_cls = batch['cls'].cpu().squeeze().numpy()
        gt_bboxes = batch['bboxes'].cpu().numpy()
        
        to_pil = ToPILImage()
        ori_img = to_pil(img_tensor)
        gt_img = np.asarray(ori_img)
        gt_img = np.copy(gt_img)
        
        if gt_cls.size > 1:
            for i in range(len(gt_cls)):
                x, y, w, h = gt_bboxes[i]
                x, y, w, h = x * gt_img.shape[1], y * gt_img.shape[0], w * gt_img.shape[1], h * gt_img.shape[0]
                x1, y1, x2, y2 = x - w / 2, y - h / 2, x + w / 2, y + h / 2
                label = self.names[gt_cls[i]]
                if any(isinstance(item, str) for item in fn):
                    color = (0, 255, 0)
                else:
                    if i in fn:
                        color = (0, 0, 255) # 如果有漏检,则将这个目标框设为红色
                    else:
                        color = (0, 255, 0) # 如果没有漏检,就是绿色目标框
                cv2.rectangle(gt_img, (int(x1), int(y1)), (int(x2), int(y2)), color, 2)
                cv2.putText(gt_img, label, (int(x1), int(y1) - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, color, 2)
                
        else:
            x, y, w, h = gt_bboxes[0]
            x, y, w, h = x * gt_img.shape[1], y * gt_img.shape[0], w * gt_img.shape[1], h * gt_img.shape[0]
            x1, y1, x2, y2 = x - w / 2, y - h / 2, x + w / 2, y + h / 2
            label = self.names[int(gt_cls.item())]
            if fn:
                if any(isinstance(item, str) for item in fn):
                    pass
                else:
                    color = (0, 0, 255)
            else:
                color = (0, 255, 0)
            cv2.rectangle(gt_img, (int(x1), int(y1)), (int(x2), int(y2)), color, 2)
            cv2.putText(gt_img, label, (int(x1), int(y1) - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, color, 2)

        # pred 这里对预测值进行处理
        pred_img = np.asarray(ori_img)
        pred_img = np.copy(pred_img)
        pred_cls = preds[0][:, 5]
        pred_bboxes = preds[0][:, :4]
        pred_conf = preds[0][:, 4]
        for i in range(len(pred_cls)):
            x1, y1, x2, y2 = pred_bboxes[i]
            label = self.names[pred_cls[i].tolist()] + "{:.2f}".format(pred_conf[i].tolist())
            if any(isinstance(item, str) for item in fp):
                color = (0, 255, 0)  # green
            else:
                if i in fp:
                    color = (0, 0, 255)  # red 如果有错检,则画红色的框框
                else:
                    color = (0, 255, 0)
            
            cv2.rectangle(pred_img, (int(x1), int(y1)), (int(x2), int(y2)), color, 2)
            cv2.putText(pred_img, label, (int(x1), int(y1) - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, color, 2)
        
        
        img_name = batch['im_file'][0].split('\\')[-1]
        output = f'{self.args.vis_badcase_path}\\{img_name}'
        if not os.path.exists(self.args.vis_badcase_path):
            os.makedirs(self.args.vis_badcase_path)
        merge_images(gt_img, pred_img, output) # 这是一个简单的图像拼接函数,为了方便真值与预测值进行对比
# gyh vis_badcase end

可能看起来有点复杂,但是在针对目标很多的图像时(例如培养皿中的细胞颗粒),这个脚本很好使!但由于保密问题,我不能展示细胞图片,但我可以展示这个脚本呈现的效果:

左边是真值,右边是预测值,左图的红色框框就是漏检的object。我们再看一张图:

右边的红色个框框就是错检的object,但是这里其实预测没有错,是标注错了,所以这个脚本还能检查标注的错例。 

分类任务的badcase可视化(Ultralytics版)

    def process_cls_preds(self, preds, targets):
        """
        Update confusion matrix for classification task.

        Args:
            preds (Array[N, min(nc,5)]): Predicted class labels.
            targets (Array[N, 1]): Ground truth class labels.
        """
        preds, targets = torch.cat(preds)[:, 0], torch.cat(targets)
        for p, t in zip(preds.cpu().numpy(), targets.cpu().numpy()):
            self.matrix[p][t] += 1

分类任务就简单多了,甚至我好像没有借助到混淆矩阵这个方法,我直接在后处理之后,将真值与预测值的类别进行对比,如果类别不一致就将这个错例保存下来,具体代码如下:
 

# gyh 0709 plot badcase
task = f'{self}'.split('.')[3]
if task == 'classify':
    if self.args.vis_cls_badcase: 
        import os
        for cls in self.names:
            base_name = self.names[cls]
            path = os.path.join(self.args.vis_cls_badcase_path, base_name)
            os.makedirs(path, exist_ok=True)
        pred_v, pred_cls = preds.max(dim=1)
        if pred_cls != batch['cls']:
            # get ori_img
            image_tensor = batch['img'].cpu().squeeze(0)
            from torchvision.transforms import ToPILImage
            to_pil = ToPILImage()
            pil_image = to_pil(image_tensor)
            image_np = np.asarray(pil_image)
            
            # get info
            label = self.names[batch['cls'].tolist()[0]]
            pred = self.names[pred_cls.tolist()[0]]
            text = f"{label}, {pred}"
            
            # save
            save_path = f'{self.args.vis_cls_badcase_path}/{label}'
            save_name = save_path + f'/{idx}.jpg'
            idx += 1
            
            # plot
            import cv2
            cv2.putText(image_np, text, (5, 10), cv2.FONT_HERSHEY_SIMPLEX, 0.3, (255, 255, 255), 1)
            cv2.imwrite(save_name, image_np)
# gyh 0709 plot badcase

这个就不展示效果啦,如果目标检测任务的badcase能弄明白,分类任务就很简单啦。

最近在做一些实例分割的项目,mask的badcase还在研究中...

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值