PyTorch实现NMS算法

20 篇文章 0 订阅
16 篇文章 0 订阅

PyTorch实现NMS算法

介绍

参考链接1:NMS 算法源码实现
参考链接2: Python实现NMS(非极大值抑制)对边界框进行过滤。
目标检测算法(主流的有 RCNN 系、YOLO 系、SSD 等)在进行目标检测任务时,可能对同一目标有多次预测得到不同的检测框,非极大值抑制(NMS) 算法则可以确保对每个对象只得到一个检测,简单来说就是“消除冗余检测”。

示例代码

以下代码实现在 PyTorch 中实现非极大值抑制(NMS)。这个函数接受三个参数:boxes(边界框),scores(每个边界框的得分),和 iou_threshold(交并比阈值)。假设输入的边界框格式为 [x1, y1, x2, y2],其中 (x1, y1) 是左上角坐标,(x2, y2) 是右下角坐标。

import torch

def nms(boxes: torch.Tensor, scores: torch.Tensor, iou_threshold: float):
    """
    Perform Non-Maximum Suppression (NMS) on bounding boxes.

    Args:
        boxes (torch.Tensor): A tensor of shape (N, 4) containing the bounding boxes
                              of shape [x1, y1, x2, y2], where N is the number of boxes.
        scores (torch.Tensor): A tensor of shape (N,) containing the scores of the boxes.
        iou_threshold (float): The IoU threshold for suppressing boxes.

    Returns:
        torch.Tensor: A tensor of indices of the boxes to keep.
    """
    # Get the areas of the boxes
    x1 = boxes[:, 0]
    y1 = boxes[:, 1]
    x2 = boxes[:, 2]
    y2 = boxes[:, 3]
    areas = (x2 - x1) * (y2 - y1)

    # Sort the scores in descending order and get the sorted indices
    _, order = scores.sort(0, descending=True)

    keep = []
    while order.numel() > 0:
        if order.numel() == 1:
            i = order.item()
            keep.append(i)
            break
        else:
            i = order[0].item()
            keep.append(i)

        # Compute the IoU of the kept box with the rest
        xx1 = torch.max(x1[i], x1[order[1:]])
        yy1 = torch.max(y1[i], y1[order[1:]])
        xx2 = torch.min(x2[i], x2[order[1:]])
        yy2 = torch.min(y2[i], y2[order[1:]])

        w = torch.clamp(xx2 - xx1, min=0)
        h = torch.clamp(yy2 - yy1, min=0)
        inter = w * h
        iou = inter / (areas[i] + areas[order[1:]] - inter)

        # Keep the boxes with IoU less than the threshold
        inds = torch.where(iou <= iou_threshold)[0]
        order = order[inds + 1]

    return torch.tensor(keep, dtype=torch.long)

代码工作原理:

  1. 计算每个边界框的面积。
  2. 根据得分对边界框进行降序排序。
  3. 依次选择得分最高的边界框,并计算它与其他边界框的 IoU。
  4. 保留 IoU 小于阈值的边界框,并继续处理剩余的边界框。
  5. 返回保留的边界框的索引。
  • 6
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
PyTorch是一个基于Python的开源机器学习库,可用于创建深度学习模型。SuperPoint是一种用于图像特征点检测和描述的深度学习网络模型。 要使用PyTorch实现SuperPoint,首先需要定义模型的结构。SuperPoint模型由主要的卷积神经网络(CNN)和后处理的非极大值抑制(NMS)组成。 在PyTorch中,可以使用nn.Module类来创建SuperPoint模型的定义。在主要的CNN中,可以使用卷积层、批量归一化层和非线性激活函数,例如ReLU。还可以使用池化层来减小特征图的尺寸。 在模型的输出中,可以使用softmax激活函数将特征点的分类概率归一化,用于确定每个像素是否为关键点。此外,还可以使用另一个卷积层来生成每个特征点的描述信息。 在训练SuperPoint模型时,可以使用已标记的图像数据集来进行有监督学习。可以定义损失函数,例如交叉熵损失,来衡量分类概率的准确性和描述信息的相似性。 在PyTorch中,可以使用torchvision库来加载训练数据集,并使用torch.optim库来定义优化器,例如随机梯度下降(SGD)来更新模型的权重和偏置。 在模型训练完成后,可以使用SuperPoint模型来检测和描述新的图像。可以将待检测的图像输入模型中,获取每个像素的分类概率,并使用NMS算法筛选出特征点。 总之,使用PyTorch实现SuperPoint需要定义模型的结构,加载训练数据集,定义损失函数和优化器,以及应用模型进行特征点检测和描述。通过训练和应用SuperPoint模型,可以从图像中提取出具有高级语义信息的关键点。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值