NMS源码解析

本文详细介绍了IOU概念在NMS中的应用,步骤包括选择最高得分box、计算IOU、筛选重叠小于阈值的box,以及多分类NMS的实现。通过源码解析,展示了如何使用PyTorch进行非极大值抑制操作以提高目标检测精度。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

一、IOU的概念 

 

二、NMS的算法原理

  1. 选取该类box中scores最大的一个,记为box_best,并保留它

  2. 计算box_best与其余的box的IOU

  3. 如果其IOU>threshold了,就舍弃这个box(因为可能这两个box表示同一目标,保留分数高的哪一个)

  4. 从最后剩余的boxes中,再找出最大scores的哪一个,如此循环往复,直到没有box为止

 三、源码解析

# ---------------------------
# 非极大值抑制(Non-Maximum Suppression,NMS),顾名思义就是抑制不是极大值的元素
# ---------------------------
import numpy as np
import torch

# 输入:
#   dets: 边界框tensor,每一个单元为(x1,y1,x2,y2,confidence),每个框都对应一个分数
#         (x1,y1)表示框的左上角坐标,(x2,y2)表示框的右下角坐标
#   thresh: iou过滤阈值
# 输出:nms处理过的边界框

def nms_cpu(dets, thresh):
    dets = dets.numpy()
    x1 = dets[:, 0]  # 取出所有的边界框左上角点的x坐标放入x1
    y1 = dets[:, 1]
    x2 = dets[:, 2]
    y2 = dets[:, 3]
    scores = dets[:, 4]

    areas = (x2 - x1 + 1) * (y2 - y1 + 1)  # 计算所有边界框的面积
    # numpy的argsort()函数:返回数组值从小到大的索引值,
    # 再加上[::-1]返回数组值从大到小的索引值,
    # 也可以order = np.argsort(-score)
    order = scores.argsort()[::-1] #分数从大到小排列的索引值

    # 每次选出scores中最大的那个边界框
    keep = []
    while order.size > 0:
        i = order[0]
        keep.append(i)  # 保留该类剩余box中得分最高的索引
        xx1 = np.maximum(x1[i], x1[order[1:]])#获取得分最高的边界框与其他所有框的的重叠区域的左上角x坐标
        yy1 = np.maximum(y1[i], y1[order[1:]])#标量和numpy取最值,结果是一个numpy
        xx2 = np.minimum(x2[i], x2[order[1:]])#此处是minimun,不是maximum。求得分最高的边界框与其他所有框的的重叠区域的右下角x坐标
        yy2 = np.minimum(y2[i], y2[order[1:]])

        # 计算重叠的面积,不重叠时面积为0
        w = np.maximum(0.0, xx2 - xx1 + 1)
        h = np.maximum(0.0, yy2 - yy1 + 1)
        inter = w * h#得到最大得分框和其他框的重叠面积

        # 计算IOU=重叠面积/(得分最大的框面积+当前的框面积-重叠面积)
        ovr = inter / (areas[i] + areas[order[1:]] - inter)
        # 保留iou小于等于阈值的边界框,其它则被过滤了
        # numpy.where() 有两种用法:
        # 1.np.where(condition, x, y):满足条件(condition),输出x,不满足输出y。
        # 2.np.where(condition):输出满足条件(即非0)元素的坐标(等价于numpy.nonzero)
        inds = np.where(ovr <= thresh)[0]#重叠面积小表示要保留
        # 因为ovr数组的长度比order数组少一个,所以这里要将所有下标后移一位,
        # 获得下一个目标区域的得分最高的一个的索引
        order = order[inds + 1]

    return torch.IntTensor(keep)
# 多分类非极大值抑制
def multiclass_nms(bboxes, scores, score_thresh=0.01, nms_thresh=0.45, pre_nms_topk=1000, pos_nms_topk=100):
    """
    bboxes存放的是所有batch所有class的框,shape是(num_batch,4)
    scores存放的是所有batch所有class的分数,shape是(num_batches,num_class)
    """
    batch_size = bboxes.shape[0]
    class_num = scores.shape[1]
    rets = []
    for i in range(batch_size):
        bboxes_i = bboxes[i]
        scores_i = scores[i]
        ret = []
        for c in range(class_num):
            scores_i_c = scores_i[c]
            keep_inds = nms(bboxes_i, scores_i_c, score_thresh, nms_thresh, pre_nms_topk, i=i, c=c)
            if len(keep_inds) < 1:
                continue
            keep_bboxes = bboxes_i[keep_inds]
            keep_scores = scores_i_c[keep_inds]
            keep_results = np.zeros([keep_scores.shape[0], 6])
            keep_results[:, 0] = c
            keep_results[:, 1] = keep_scores[:]
            keep_results[:, 2:6] = keep_bboxes[:, :]
            ret.append(keep_results)
        if len(ret) < 1:
            rets.append(ret)
            continue
        ret_i = np.concatenate(ret, axis=0)
        scores_i = ret_i[:, 1]
        if len(scores_i) > pos_nms_topk:
            inds = np.argsort(scores_i)[::-1]
            inds = inds[:pos_nms_topk]
            ret_i = ret_i[inds]

        rets.append(ret_i)

    return rets
# 非极大值抑制
def nms(bboxes, scores, score_thresh, nms_thresh, pre_nms_topk, i=0, c=0):
    """
    bboxes存放的是坐标,shape是(num_boxes,4)
    scores存放的是分数,shape是(num_boxes),bboxes和scores对应的索引是一致的
    """
    inds = np.argsort(scores)
    inds = inds[::-1]
    keep_inds = []
    while(len(inds) > 0):
        cur_ind = inds[0] #最大分数的索引
        cur_score = scores[cur_ind]
        # if score of the box is less than score_thresh, just drop it
        if cur_score < score_thresh:#判断最大得分是否小于score_thresh,如果是则丢弃
            break

        keep = True
        for ind in keep_inds:
            current_box = bboxes[cur_ind]
            remain_box = bboxes[ind]
            iou = box_iou_xyxy(current_box, remain_box)
            if iou > nms_thresh:
                keep = False
                break
        if i == 0 and c == 4 and cur_ind == 951:
            print('suppressed, ', keep, i, c, cur_ind, ind, iou)
        if keep:
            keep_inds.append(cur_ind)
        inds = inds[1:]

    return np.array(keep_inds)

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值