手写 非极大值抑制(NMS)

1. NMS简介

        目标检测中的非极大值抑制(NMS)是一种用于去除重叠边界框的方法。具体来说,NMS通过计算每个边界框与其他所有边界框之间的交并比(IoU),来确定它们之间的相似度。然后,它会选择具有最高得分的边界框,并将其添加到最终结果中。接下来,NMS会删除与当前最高得分边界框具有高IoU的所有其他边界框。这个过程会一直重复,直到所有边界框都被处理完毕 。

2. IoU

        进行非极大值抑制(NMS)之前,必须了解边界框之间的交并比(IoU)。

import torch
# 假设边框的坐标为(x1, y1, x2, y2), box1维度为(N, 4), box2维度为(M, 4)
def BBox_IoU(box1, box2):
    # 交集左上角坐标
    lt = torch.max(box1[:,None,:2], box2[None,:,:2])
    # 交集右下角坐标
    rb = torch.min(box1[:,None,2:], box2[None,:,2:])
    
    # 交集的面积
    inter = (rb - lt).clamp(0).prod(2)
    
    box1_area = (box1[..., 2:] - box1[..., :2]).clamp(0).prod(1)
    box2_area = (box2[..., 2:] - box2[..., :2]).clamp(0).prod(1)
    # IoU
    iou = inter/(box1_area[:,None] + box2_area[None,:] - inter + 1e-9)
    return iou

3. NMS

# bbox维度为[N,4],scores维度为[N,], 均为tensor
def nms(bbox, scores, threshold=0.5, conf_threshold=0.001):
    # 将分数从大到小排序
    scores, idx = torch.sort(scores, descending=True)
    # 去掉分数 小于等于 conf_threshold 的 预测框,分数,以及相应的索引 
    score_mask = scores > conf_threshold
    bbox = bbox[idx][score_mask]
    scores = scores[score_mask]
    idx = idx[score_mask]
    # 保留下来的预测框索引
    keep_idx = []
    while (len(idx) > 0):
        # 如果只剩下一个框,直接保留下来
        if len(idx) == 1:
            keep_idx.append(idx[0])
            break
        # 保留下来得分最高的框
        keep_idx.append(idx[0])
        # 通过IoU获得可以保留下来的预测框的mask
        iou_mask = BBox_IoU(bbox[1:], bbox[0].unsqueeze(0)).view(-1) < threshold
        # 保留下来的预测框及其索引
        idx = idx[1:][iou_mask]
        bbox = bbox[1:][iou_mask]
    keep_idx = torch.as_tensor(keep_idx)
    return keep_idx
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值