Python实现NMS非极大值抑制算法

Python实现NMS非极大值抑制算法

NMS算法具体过程我就不多赘述,网上有很多该算法的讲解,下面直接上代码

bbox_iou是用于计算边框之间IoU值的代码,是直接copy的YoloV5中的代码

nms和nms_都是实现NMS算法的代码,返回值都是NMS过滤后的边框索引

# NMS非极大值过滤
import torch
import torchvision
from torch import Tensor
import math

def bbox_iou(box1, box2, xywh=True, GIoU=False, DIoU=False, CIoU=False, eps=1e-7):
    if xywh:  # transform from xywh to xyxy
        (x1, y1, w1, h1), (x2, y2, w2, h2) = box1.chunk(4, -1), box2.chunk(4, -1)
        w1_, h1_, w2_, h2_ = w1 / 2, h1 / 2, w2 / 2, h2 / 2
        # 得到预测边框和真实边框左上角坐标和右下角坐标
        b1_x1, b1_x2, b1_y1, b1_y2 = x1 - w1_, x1 + w1_, y1 - h1_, y1 + h1_
        b2_x1, b2_x2, b2_y1, b2_y2 = x2 - w2_, x2 + w2_, y2 - h2_, y2 + h2_
    else:  # x1, y1, x2, y2 = box1
        b1_x1, b1_y1, b1_x2, b1_y2 = box1.chunk(4, -1)
        b2_x1, b2_y1, b2_x2, b2_y2 = box2.chunk(4, -1)
        w1, h1 = b1_x2 - b1_x1, b1_y2 - b1_y1 + eps
        w2, h2 = b2_x2 - b2_x1, b2_y2 - b2_y1 + eps

    # Intersection area 计算重叠区域面积
    # 重叠区域面积计算方式:左上角坐标选最大值,右下角坐标选最小值,用得到的右下角坐标减去左上角坐标得到重叠区域的长和宽,最后长乘宽得到重叠区域面积
    inter = (torch.min(b1_x2, b2_x2) - torch.max(b1_x1, b2_x1)).clamp(0) * \
            (torch.min(b1_y2, b2_y2) - torch.max(b1_y1, b2_y1)).clamp(0)

    # Union Area    计算两个边框合并面积 = 边框1面积 + 边框2面积 - 重叠区域面积
    union = w1 * h1 + w2 * h2 - inter + eps

    # IoU   计算IoU值
    iou = inter / union
    if CIoU or DIoU or GIoU:
        cw = torch.max(b1_x2, b2_x2) - torch.min(b1_x1, b2_x1)  # convex (smallest enclosing box) width 最小外接矩形宽度
        ch = torch.max(b1_y2, b2_y2) - torch.min(b1_y1, b2_y1)  # convex height 最小外接矩形高度
        if CIoU or DIoU:  # Distance or Complete IoU https://arxiv.org/abs/1911.08287v1
            c2 = cw ** 2 + ch ** 2 + eps  # convex diagonal squared 最小外接矩形的对角线距离
            rho2 = ((b2_x1 + b2_x2 - b1_x1 - b1_x2) ** 2 + (b2_y1 + b2_y2 - b1_y1 - b1_y2) ** 2) / 4  # center dist ** 2
            if CIoU:  # https://github.com/Zzh-tju/DIoU-SSD-pytorch/blob/master/utils/box/box_utils.py#L47
                v = (4 / math.pi ** 2) * torch.pow(torch.atan(w2 / h2) - torch.atan(w1 / h1), 2)
                with torch.no_grad():
                    alpha = v / (v - iou + (1 + eps))
                return iou - (rho2 / c2 + v * alpha)  # CIoU
            return iou - rho2 / c2  # DIoU
        c_area = cw * ch + eps  # convex area
        return iou - (c_area - union) / c_area  # GIoU https://arxiv.org/pdf/1902.09630.pdf
    return iou  # IoU

def nms(
        boxes: Tensor, scores: Tensor,
        score_threshold: float = 0.25,
        iou_threshold: float = 0.45,
        giou: bool = False, diou: bool = False, ciou: bool = False,
        soft_nms: bool = False,
        sigma: float = 0.5
):
    """
    计算nms
    :param boxes: 边框 [N,4] [x1,y1,x2,y2]
    :param scores: 置信度 [N]
    :param score_threshold: 最终返回边框时候的置信度阈值
    :param iou_threshold: 在更新边框的score的时候,仅对iou超过该阈值的边框进行对应的score更新
    :param giou: 是否实现GIoU-NMS/GIoU-soft-nms
    :param diou: 是否实现DIoU-NMS/DIoU-soft-nms
    :param ciou: 是否实现CIoU-NMS/CIoU-soft-nms
    :param soft_nms: 是否实现soft-nms
    :param sigma: 当更新方式为soft-nms时的超参数
    :return:
    """
    # 获取总的边框数目
    num_boxes = boxes.shape[0]
    # 获取device信息
    device = boxes.device
    # 因为后续要对scores进行改动,先获取一个scores的clone
    scores = scores.clone().to(device)
    # 创建一个indexes信息,并绑定到boxes上
    indexes = torch.arange(0, num_boxes).to(device).reshape(num_boxes, 1)
    boxes = torch.cat((boxes.clone(), indexes), dim = 1)

    # 获取score最大的边框,计算它与其他边框的IoU值,并根据IoU值更新其他边框的socre
    for i in range(num_boxes):
        # 获取当前边框的score,若小于置信度阈值,则该边框直接无视
        score = scores[i]
        if score < score_threshold:
            continue
        # 获取剩余边框中置信度最大的边框,并和当前边框交换
        if i != num_boxes - 1:
            max_score, max_index = torch.max(scores[i+1:], dim=0)
            if score < max_score:
                j = max_index.item() + i+1  # 得到最大边框在整个boxes中的位置
                boxes[i], boxes[j] = boxes[j].clone(), boxes[i].clone()
                scores[i], scores[j] = scores[j].clone(), scores[i].clone()
        # 交换完位置后,当前边框就是后面所有边框中置信度最大的,再计算后面边框与它的IoU值,方便剔除重叠边框
        iou = bbox_iou(boxes[i, :4], boxes[i+1:, :4], xywh=False, GIoU=giou, DIoU=diou, CIoU=ciou)[:, 0]
        # 得到IoU值后,根据IoU值选择直接将大于IoU阈值的置信度置为0获取采用soft_nms的方式进行减小置信度
        if soft_nms:
            iou = torch.where(iou < iou_threshold, torch.zeros_like(iou), iou)
            weights = torch.exp(- (iou * iou) / sigma)
        else:
            weights = torch.where(iou < iou_threshold, torch.ones_like(iou), torch.zeros_like(iou))
        # 更新后面边框的置信度信息
        scores[i+1:] = weights * scores[i+1:]

    # 返回置信度满足要求的边框下标
    remain = boxes[:, 4][scores > score_threshold].int()
    return remain


def nms_(
        boxes: Tensor, scores: Tensor,
        score_threshold: float = 0.25,
        iou_threshold: float = 0.45,
        giou: bool = False, diou: bool = False, ciou: bool = False,
        soft_nms: bool = False,
        sigma: float = 0.5
):
    """
    计算nms
    :param boxes: 边框 [N,4] [x1,y1,x2,y2]
    :param scores: 置信度 [N]
    :param score_threshold: 最终返回边框时候的置信度阈值
    :param iou_threshold: 在更新边框的score的时候,仅对iou超过该阈值的边框进行对应的score更新
    :param giou: 是否实现GIoU-NMS/GIoU-soft-nms
    :param diou: 是否实现DIoU-NMS/DIoU-soft-nms
    :param ciou: 是否实现CIoU-NMS/CIoU-soft-nms
    :param soft_nms: 是否实现soft-nms
    :param sigma: 当更新方式为soft-nms时的超参数
    :return:
    """
    # 获取总的边框数目
    N = boxes.shape[0]
    # 获取device信息
    device = boxes.device
    # 因为后续需要修改scores,所以获取一个scores的clone
    scores = scores.clone().to(device)
    # 创建boxes的索引信息,并和boxes.clone()绑定
    indexes = torch.arange(0, N).reshape(N, 1).to(device)
    boxes = torch.cat((boxes.clone(), indexes), dim=1)

    # 创建一个列表用于记录符合要求的box的下标
    remain = []
    # 循环,获取当前boxes中score最大的box,并计算其和其余boxes的IoU值,最后剔除IoU值大于阈值的boxes或者降低它们的置信度
    while True:
        # 获取最大score已经其对应的下标
        max_score, max_index = torch.max(scores[:], dim=0)
        # 如果当前score中最大的score都小于阈值,则直接结束循环
        if max_score < score_threshold:
            break
        # 记录下需要保留的box下标
        remain.append(max_index)
        # 计算其他boxes和该box的IoU值
        iou = bbox_iou(boxes[max_index, :4], boxes[:, :4], xywh=False, GIoU=giou, DIoU=diou, CIoU=ciou)[:, 0]
        # 根据IoU值更新其他边框的置信度信息
        if soft_nms:
            iou = torch.where(iou < iou_threshold, torch.zeros_like(iou), iou)
            weights = torch.exp(- (iou * iou) / sigma)
            # 注意:soft_nms此处要将max_index出的weight置为0
            weights[max_index] = 0
        else:
            weights = torch.where(iou < iou_threshold, torch.ones_like(iou), torch.zeros_like(iou))
        # 更新scores值
        scores = scores * weights
    return torch.tensor(remain).int()


if __name__ == '__main__':
    boxes = torch.tensor([
        [50, 50, 100, 100],
        [52, 54, 106, 98],
        [34, 32, 78, 38],
        [33, 32, 76, 38]
    ], dtype=torch.float)
    scores = torch.tensor([0.9, 0.8, 0.85, 0.86])
    indexes_nms = torchvision.ops.nms(boxes, scores, 0.4)
    print(indexes_nms)
    print(nms(boxes, scores, iou_threshold=0.4))
    print(nms(boxes, scores, iou_threshold=0.4, score_threshold=0.2, soft_nms=True))
    print(nms_(boxes, scores, iou_threshold=0.4))
    print(nms_(boxes, scores, iou_threshold=0.4, score_threshold=0.2, soft_nms=True))
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值