pytorch nonzero_NMS算法详解(附Pytorch实现代码)

85c0d8ba2a9441e71ee84fdac3549ac4.png

大约四个月前,从懵懵懂懂的本科生就莫名其妙的成为了研究生,随着自己的兴趣,选择了计算机视觉(Computer Vision),跟着导师入了目标检测的领域,这一切来得太快,一开始我也无所适从,由于自己有些基础,接触了第一个目标检测算法Single Shot Multibox Detector(SSD算法)。花了两个星期,大概读懂了论文,然后就去网上下载别人复现的代码开始调试,过了几天,可以跑起来了,很高兴,但是我内心其实很虚,因为对当时的我来说一切都是模模糊糊,于是我开始读源码,然后这种感觉有一个图可以很好地描述:

935f3ec73c4fa7fda966848e350a4d9e.png

各种矩阵运算和逻辑让我痛不欲生,并且大佬们的源码都是各种矩阵变换,运算,基本很少有for循环什么的,就很懵逼!曾经一度想要放弃读源码!

后来有一天,在刷知乎的时候刷到一个回答,一个旷视科技的大佬关于校招算法工程师的言论,有一句话:NMS都不会,做什么Detection! 深深的刺激到我了,我是做算法的,每个细节,每个底层我都有义务去掌握~

SO,今天来谈谈Non-maximum suppression(非极大值抑制)算法~~

IOU(区域交并比)

IOU的原称为Intersection over Union,也就是两个box区域的交集比上并集,下面的示意图就很好理解,用于确定两个框的位置像素距离~

a0f001adbb803d758066d9adc779f6b7.png

思路:(注意维度一致)

  • 首先计算两个box左上角点坐标的最大值和右下角坐标的最小值
  • 然后计算交集面积
  • 最后把交集面积除以对应的并集面积

其Pytorch源码为:(注意矩阵维度的变化)

# IOU计算
    # 假设box1维度为[N,4]   box2维度为[M,4]
    def iou(self, box1, box2):
        N = box1.size(0)
        M = box2.size(0)

        lt = torch.max(  # 左上角的点
            box1[:, :2].unsqueeze(1).expand(N, M, 2),   # [N,2]->[N,1,2]->[N,M,2]
            box2[:, :2].unsqueeze(0).expand(N, M, 2),   # [M,2]->[1,M,2]->[N,M,2]
        )

        rb = torch.min(
            box1[:, 2:].unsqueeze(1).expand(N, M, 2),
            box2[:, 2:].unsqueeze(0).expand(N, M, 2),
        )

        wh = rb - lt  # [N,M,2]
        wh[wh < 0] = 0   # 两个box没有重叠区域
        inter = wh[:,:,0] * wh[:,:,1]   # [N,M]

        area1 = (box1[:,2]-box1[:,0]) * (box1[:,3]-box1[:,1])  # (N,)
        area2 = (box2[:,2]-box2[:,0]) * (box2[:,3]-box2[:,1])  # (M,)
        area1 = area1.unsqueeze(1).expand(N,M)  # (N,M)
        area2 = area2.unsqueeze(0).expand(N,M)  # (N,M)

        iou = inter / (area1+area2-inter)
        return iou

其中:

  • torch.unsqueeze(1) 表示增加一个维度,增加位置为维度1
  • torch.squeeze(1) 表示减少一个维度

NMS(非极大抑制)

NMS算法一般是为了去掉模型预测后的多余框,其一般设有一个nms_threshold=0.5,具体的实现思路如下:

  1. 选取这类box中scores最大的哪一个,记为box_best,并保留它
  2. 计算box_best与其余的box的IOU
  3. 如果其IOU>0.5了,那么就舍弃这个box(由于可能这两个box表示同一目标,所以保留分数高的哪一个)
  4. 从最后剩余的boxes中,再找出最大scores的哪一个,如此循环往复
# NMS算法
    

其中:

  • torch.numel() 表示一个张量总元素的个数
  • torch.clamp(min, max) 设置上下限
  • tensor.item() 把tensor元素取出作为numpy数字

下面为算法测试效果(人头检测):

  • 不使用NMS算法(产生了40个预测框,都重复在一起)

9f19af1c284e5eb2c01a823fb66cba3e.png
  • 使用NMS算法之后(产生了5个预测框)

b3944764b9be946ff5209fc13e04c200.png

总结

NMS算法通常用于测试阶段,本文只是介绍NMS算法的底层实现,完整测试工程等我过段时间会在GitHub上进行开源的,方便大家学习,嗯嗯,现在是这样考虑的,希望大家多多交流学习!

目标检测中的数据增强算法实现:

TeddyZhang:目标检测:数据增强(Numpy+Pytorch)​zhuanlan.zhihu.com
d34ae923487b2c36e7d9be0fda1c7329.png
  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值