NMS非极大值抑制与单分类与多分类的代码实战

NMS的理论

非极大值抑制(Non-Maximun Suppression,nms),就是抑制不是极大值的元素,可以理解为局部最大搜索。用于目标检测中,就是在某个局部范围内,选取置信度高的目标检测框,而抑制置信度低的误检框。一般来说,用在当解析模型输出到目标框时,目标框会非常多,具体数量由anchor数量决定,其中有很多重复的框定位到同一个目标,nms用来去除这些重复的框,获得真正的目标框。如下图所示,人、马、车上有很多框,通过nms,得到唯一的检测框。
在这里插入图片描述

例如在行人检测中,滑动窗口经提取特征,经分类器分类识别后,每个窗口都会得到一个分数。但是滑动窗口会导致很多窗口与其他窗口存在包含或者大部分交叉的情况。这时就需要用到NMS来选取那些邻域里分数最高(是行人的概率最大),并且抑制那些分数低的窗口。

在这里插入图片描述
NMS算法是目标检测中取出冗余检测框的常用算法,它的基本步骤为:

(1)选择某一类物体的所有的检测框和置信度,将其放到一个容器中

(2)对检测框的置信度进行降序排序

(3)选择容器中,检测框的置信度最大的bbox,将其保存下来,然后与容器中剩余的元素依次进行IOU计算

(4)如果IOU计算的结果大于置信度阈值的话,将该检测框及其置信度从容器中剔除出去

(5)重复3-4步骤,直至容器为空,才停止算法。

单分类nms代码

import numpy as np
 
bboxes = np.array([[100, 100, 210, 210, 0.72],
                   [250, 250, 420, 420, 0.8],
                   [220, 220, 320, 330, 0.92],
                   [100, 100, 210, 210, 0.72],
                   [230, 240, 325, 330, 0.81],
                   [220, 230, 315, 340, 0.9]])
 
 
# 进行非极大值抑制
def nms(iou_thresh=0.5, conf_threash = 0.5):
    # 基本思路:
    # (1) 将置信度进行降序排序,然后选择置信度最大的bbox,将其保存下来
    # (2) 将置信度最大的bbox和其他剩余的bbox进行交并比计算,将交并比大于阈值的bbox从这个集合中剔除出去
    # (3) 如果这个集合不为空的话,我们就重复上面的计算
    # 为了提高效率,我们保留bbox不动,最终保留的也都是bbox在原集合中的索引
    x1, y1, x2, y2, confidence = bboxes[:, 0], bboxes[:, 1], bboxes[:, 2], bboxes[:, 3], bboxes[:, 4]
    # 计算面积
    area = (x2 - x1) * (y2 - y1)
    indices = confidence.argsort()[::-1]
    keep = []
    while indices.size > 0:
        idx_self, idx_other = indices[0], indices[1:]
        # 如果置信度小于阈值的话,那么后面的bbox就都不符合要求了,直接退出就行了
        if confidence[idx_self] < conf_threash:
            break
        keep.append(idx_self)
        # 计算交集
        xx1, yy1 = np.maximum(x1[idx_self], x1[idx_other]), np.maximum(y1[idx_self], y1[idx_other])
        xx2, yy2 = np.minimum(x2[idx_self], x2[idx_other]), np.minimum(y2[idx_self], y2[idx_other])
        w, h = np.maximum(0, xx2 - xx1), np.maximum(0, yy2 - yy1)
        intersection = w * h
        # 计算并集(两个面积和-交集)
        union = area[idx_self] + area[idx_other] - intersection
        iou = intersection / union
        # 只保留iou小于等于阈值的元素
        print(np.where(iou <= iou_thresh))
        keep_idx = np.where(iou <= iou_thresh)[0]
        indices = indices[keep_idx + 1]
 
    return np.array(keep)
 
 
# 进行非极大值抑制
if __name__ == '__main__':
    print(nms())

多分类nms代码

    def non_max_suppression(self, prediction, num_classes, input_shape, image_shape, letterbox_image, conf_thres=0.5, nms_thres=0.4):
        #----------------------------------------------------------#
        #   将预测结果的格式转换成左上角右下角的格式。
        #   prediction  [batch_size, num_anchors, 85]
        #----------------------------------------------------------#
        box_corner          = prediction.new(prediction.shape)
        box_corner[:, :, 0] = prediction[:, :, 0] - prediction[:, :, 2] / 2
        box_corner[:, :, 1] = prediction[:, :, 1] - prediction[:, :, 3] / 2
        box_corner[:, :, 2] = prediction[:, :, 0] + prediction[:, :, 2] / 2
        box_corner[:, :, 3] = prediction[:, :, 1] + prediction[:, :, 3] / 2
        prediction[:, :, :4] = box_corner[:, :, :4]                  # 物体有无的置信度
        output = [None for _ in range(len(prediction))]
        for i, image_pred in enumerate(prediction):
            #----------------------------------------------------------#
            #   对种类预测部分取max
            #   class_conf  [num_anchors, 1]    种类置信度
            #   class_pred  [num_anchors, 1]    种类
            #----------------------------------------------------------#
            class_conf, class_pred = torch.max(image_pred[:, 5:5 + num_classes], 1, keepdim=True)

            #----------------------------------------------------------#
            #   利用置信度进行第一轮筛选
            #----------------------------------------------------------#
            conf_mask = (image_pred[:, 4] * class_conf[:, 0] >= conf_thres).squeeze()   # (物体有无的置信度*种类置信度)与阈值相比较
           #----------------------------------------------------------#
            #   根据置信度进行预测结果的筛选
            #----------------------------------------------------------#
            image_pred = image_pred[conf_mask]
            class_conf = class_conf[conf_mask]
            class_pred = class_pred[conf_mask]
            if not image_pred.size(0):
                continue
            #-------------------------------------------------------------------------#
            #   detections  [num_anchors, 7]
            #   7的内容为:x1, y1, x2, y2, obj_conf, class_conf, class_pred
            #-------------------------------------------------------------------------#
            detections = torch.cat((image_pred[:, :5], class_conf.float(), class_pred.float()), 1)

            #------------------------------------------#
            #   获得预测结果中包含的所有种类
            #------------------------------------------#
            unique_labels = detections[:, -1].cpu().unique()

            if prediction.is_cuda:
                unique_labels = unique_labels.cuda()
                detections = detections.cuda()

            for c in unique_labels:
                #------------------------------------------#
                #   获得某一类得分筛选后全部的预测结果
                #------------------------------------------#
                detections_class = detections[detections[:, -1] == c]

                #------------------------------------------#
                #   使用官方自带的非极大抑制会速度更快一些!
                #------------------------------------------#
                keep = nms(
                    detections_class[:, :4],
                    detections_class[:, 4] * detections_class[:, 5],
                    nms_thres
                )
                max_detections = detections_class[keep]
                
                # # 按照存在物体的置信度排序
                # _, conf_sort_index = torch.sort(detections_class[:, 4]*detections_class[:, 5], descending=True)
                # detections_class = detections_class[conf_sort_index]
                # # 进行非极大抑制
                # max_detections = []
                # while detections_class.size(0):
                #     # 取出这一类置信度最高的,一步一步往下判断,判断重合程度是否大于nms_thres,如果是则去除掉
                #     max_detections.append(detections_class[0].unsqueeze(0))
                #     if len(detections_class) == 1:
                #         break
                #     ious = bbox_iou(max_detections[-1], detections_class[1:])
                #     detections_class = detections_class[1:][ious < nms_thres]
                # # 堆叠
                # max_detections = torch.cat(max_detections).data
                
                # Add max detections to outputs
                output[i] = max_detections if output[i] is None else torch.cat((output[i], max_detections))
            
            if output[i] is not None:
                output[i]           = output[i].cpu().numpy()
                box_xy, box_wh      = (output[i][:, 0:2] + output[i][:, 2:4])/2, output[i][:, 2:4] - output[i][:, 0:2]
                output[i][:, :4]    = self.yolo_correct_boxes(box_xy, box_wh, input_shape, image_shape, letterbox_image)
        return output

单分类的nms,就是不区分物体的类别,按照是否有物体的概率看成置信度;
多分类的nms,就是要区分物体的类别,按照是否有物体的概率与每个类别的概率的乘积看成置信度(参考yolo4);

参考链接:https://blog.csdn.net/qq_38316300/article/details/120174900

  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值