YOLOv3-step2过滤多余的预测框

假定原图是416x416,则网络会输出的原始检测框个数:13*13*3  + 26*26* + 52*52*3 = 10647个。需要设置规则过滤掉多余的预测框。

规则简单来说,三步:

(1)删除目标概率(框内有目标的概率)小于阈值的检测框(小于则说明没有目标);

(2)遍历结果中的所有类别,对属于当前类别的所有预测框按目标概率从大到小排序;

(3)遍历排序后的预测框,计算当前预测框与其他同类别的预测框之间的IOU,删除IOU值大于阈值的预测框(大于则说明预测框重复了);

def write_results(prediction, confidence, num_classes, nms_conf=0.4):
    """
    subject our output to object score thresholding and Non-maximal suppression
    [b, 13*13*3 + 26*26*3 + 52*52*3, 85] -> [B, D, 8]. 8=[x1, y1, x2, y2, obj_score, cls_score, cls]


    :param prediction: Tensor. [b, num_bbox, attr_bbox]. num_bbox = 13*13*3 + 26*26*3 + 52*52*3, attr_bbox = 85
    :param confidence: objectness score threshold
    :param num_classes: 80, in our case
    :param nms_conf: the NMS IoU threshold
    :return: [b, D, (1+ 4+1+ 1+1)]
    a tensor of shape D x 8. Here D is the true detections in all of images, each represented by a row.
    Each detections has 8 attributes, namely,
    index of the image in the batch to which the detection belongs to,
    4 corner coordinates, objectness score, the score of class with maximum confidence, and the index of that class.
    """

    # 1, performing object score threshold
    # For each of the bounding box having a objectness score below a threshold,
    # we set the values of it's every attribute (entire row representing the bounding box) to zero.
    # (1) use float() to convert False/True to 0/1.
    # (2) use unsqueeze(2) to convert [1, 10647] to [1, 10647, 1]
    conf_mask = (prediction[:, :, 4] > confidence).float().unsqueeze(2)
    # 由于conf_mask非0即1,如果conf_mask有一个0,通过广播机制,生成一行都是0,再与prediction相乘,将该bbox所有属性归0
    prediction = prediction * conf_mask  # prediction.shape: [1, 10647, 85]

    # 2, Performing Non-maximum Suppression

    # (1)  x1y1_x2y2覆盖xy_wh
    """
    The bounding box attributes we have now are described by the center coordinates, as well as the height and width 
    of the bounding box. However, it's easier to calculate IoU of two boxes, 
    using coordinates of a pair of diagonal corners of each box."""
    # So, we transform the (center x, center y, height, width) attributes of our boxes,
    # to (top-left corner x, top-left corner y, right-bottom corner x, right-bottom corner y).
    box_corner = prediction.new(prediction.shape)  # type和device都与prediction保持一致
    box_corner[:, :, 0] = (prediction[:, :, 0] - prediction[:, :, 2] / 2)  # x = x_c - w/2
    box_corner[:, :, 1] = (prediction[:, :, 1] - prediction[:, :, 3] / 2)  # y = y_c - h/2
    box_corner[:, :, 2] = (prediction[:, :, 0] + prediction[:, :, 2] / 2)
    box_corner[:, :, 3] = (prediction[:, :, 1] + prediction[:, :, 3] / 2)
    prediction[:, :, :4] = box_corner[:, :, :4]  # x1y1_x2y2覆盖xy_wh

    """
    The number of true detections in every image may be different. For example, a batch of size 3 
    where images 1, 2 and 3 have 5, 2, 4 true detections respectively. Therefore, confidence thresholding and 
    NMS has to be done for one image at once. This means, we cannot vectorise the operations involved, and 
    must loop over the first dimension of prediction (containing indexes of images in a batch)."""
    batch_size = prediction.size(0)

    # indicate that we haven't initialized output, a tensor
    # we will use to collect true detections across the entire batch.
    write = False

    for ind in range(batch_size):
        image_pred = prediction[ind]  # image Tensor. [num_bbox, attr_bbox] = [10647, 85]
        # confidence threshholding
        # NMS
        """Once inside the loop, let's clean things up a bit. Notice each bounding box row has 85 attributes, 
        out of which 80 are the class scores. At this point, we're only concerned with the class score having the 
        maximum value. So, we remove the 80 class scores from each row, and instead add the index of the class having 
        the maximum values, as well the class score of that class. """
        max_conf, max_conf_score = torch.max(image_pred[:, 5:5 + num_classes], 1)  # 列维度求最大值,即每个bbox的目标最大可能的类别
        max_conf = max_conf.float().unsqueeze(1)  # values 加不加float(),没发现有啥区别。max_conf.shape = 10647 -> [10647, 1]
        max_conf_score = max_conf_score.float().unsqueeze(1)  # indices. 10647 -> [10647, 1]
        seq = (image_pred[:, :5], max_conf, max_conf_score)  # (x1, y1, x2, y2, object_score, max_conf, max_conf_indice)
        image_pred = torch.cat(seq, 1)  # -> [10647, 7]

        # Remember we had set the bounding box rows having a object confidence less than the threshold to zero?
        # Let's get rid of them.
        non_zero_ind = (torch.nonzero(image_pred[:, 4]))  # 找到有目标的bbox所在的行
        try:
            image_pred_ = image_pred[non_zero_ind.squeeze(), :].view(-1, 7)  # 有目标的预测属性,4个位置属性,1个目标概率,2个类别属性
        except:
            continue  # 当前图片没有目标,上面代码会报错,则continue.

        # For PyTorch 0.4 compatibility
        # Since the above code with not raise exception for no detection
        # as scalars are supported in PyTorch 0.4
        if image_pred_.shape[0] == 0:  # 当前图片没有目标
            continue
        #

        # Get the various classes detected in the image
        # image_pred_.shape: [10, 7]
        # image_pred_[:, -1].shape: 7
        img_classes = unique(image_pred_[:, -1])  # -1 index holds the class index. 获取当前图片有哪些类别

        for cls in img_classes:  # 遍历所有类别,对每个类别获取当前类别的所有dt
            # perform NMS

            # get the detections with one particular class
            cls_mask = image_pred_ * ((image_pred_[:, -1] == cls).float().unsqueeze(1))  # 不是当前类别的bbox属性值置为0
            class_mask_ind = torch.nonzero(cls_mask[:, -2]).squeeze()  # 获取当前类别的索引
            image_pred_class = image_pred_[class_mask_ind].view(-1, 7)  # 获取当前类别的检测结果,也可以cls_mask[class_mask_ind]

            # sort the detections, such that the entry with the maximum objectness
            # confidence is at the top
            conf_sort_index = torch.sort(image_pred_class[:, 4], descending=True)[1]  # sort返回值[0]和索引[1]
            image_pred_class = image_pred_class[conf_sort_index]  # [num_dt, 7]
            idx = image_pred_class.size(0)  # Number of detections

            for i in range(idx):  # 同一个类别有多个dt,按object score从大到小排序,然后遍历dt,其他dt和当前dt iou过大(nms_thresh)则去除其他dt.
                # Get the IOUs of all boxes that come after the one we are looking at
                # in the loop
                try:  # 当前的dt和其他同一类别的所有dt,求iou
                    ious = bbox_iou(image_pred_class[i].unsqueeze(0), image_pred_class[i + 1:])  # [1,7], [2, 7],顺序随意
                except ValueError:
                    break

                except IndexError:  # 遍历结束, 上面的i+1会报错IndexError,或者image_pred_class只剩下一个,跳出循环。
                    break

                # Zero out all the detections that have IoU > nms_threshold,因为是从最大object score开始
                iou_mask = (ious < nms_conf).float().unsqueeze(1)  # [num_current_dt, 1], 其他dt与当前dt之间iou小的才保留
                image_pred_class[i + 1:] *= iou_mask  # 其他dt:image_pred_class[i + 1:]. [2, 7], 不符合的置为0

                # Remove the non-zero entries
                non_zero_ind = torch.nonzero(image_pred_class[:, 4]).squeeze()  # 4索引是object score
                image_pred_class = image_pred_class[non_zero_ind].view(-1, 7)  # 覆盖,只保留不为0的dt

            # batch_ind.shape: [num_real_dt, 1], 里面保存的值是图片ind
            batch_ind = image_pred_class.new(image_pred_class.size(0), 1).fill_(
                ind)  # Repeat the batch_id for as many detections of the class cls in the image
            seq = batch_ind, image_pred_class  # seq: 哪张图片,真实检测结果

            if not write:
                output = torch.cat(seq, 1)
                write = True
            else:
                out = torch.cat(seq, 1)
                output = torch.cat((output, out))

    try:
        return output
    except:
        return 0  # there's hasn't been a single detection in any images of the batch.

 

  • 2
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Mr.Q

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值