Faster-RCNN Explained and torchvision Source Code Walkthrough (Part 5): NMS (Non-Maximum Suppression)

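Before introducing NMS (Non-Maximum Suppression), we first need the concept of IoU (Intersection over Union). IoU measures how much two boxes overlap: it is the area of their intersection divided by the area of their union, as shown in the figure below. Comparing a predicted box with a ground-truth box, it tells us how good the prediction is. NMS uses IoU to filter out boxes that overlap each other heavily.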

[Figure: computing IoU]
Image from [Adrian_Rosebrock]_Deep_Learning_for_Computer_Vision-3
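
As a minimal sketch, assuming boxes are given as (x1, y1, x2, y2) corner coordinates (the same layout the code later in this post uses), the IoU of two boxes can be computed as below. The function name `iou` is only for illustration.

```python
def iou(box_a, box_b):
    """Intersection over Union of two boxes in (x1, y1, x2, y2) format."""
    ax1, ay1, ax2, ay2 = box_a
    bx1, by1, bx2, by2 = box_b

    # Corners of the intersection rectangle
    ix1, iy1 = max(ax1, bx1), max(ay1, by1)
    ix2, iy2 = min(ax2, bx2), min(ay2, by2)

    # Clamp at 0 so that non-overlapping boxes get zero intersection
    inter = max(0.0, ix2 - ix1) * max(0.0, iy2 - iy1)

    area_a = (ax2 - ax1) * (ay2 - ay1)
    area_b = (bx2 - bx1) * (by2 - by1)
    union = area_a + area_b - inter
    return inter / union if union > 0 else 0.0


# Two 2x2 boxes overlapping in a 1x1 square: IoU = 1 / (4 + 4 - 1) = 1/7
print(iou((0, 0, 2, 2), (1, 1, 3, 3)))
```

Identical boxes give 1.0 and disjoint boxes give 0.0, which is why a single threshold on IoU can decide whether two boxes describe the same object.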

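The NMS algorithm proceeds as follows:

  1. We already have a score for each box indicating whether it contains an object (objectness), so we sort the boxes by this score from high to low.
  2. Then we go through the sorted boxes in order. For each box that has not been suppressed, we compute its IoU with every remaining lower-scored box and remove (suppress) any box whose IoU exceeds a threshold (e.g. 0.7).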
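
The two steps above can be sketched in plain Python. This is a minimal, unoptimized sketch assuming (x1, y1, x2, y2) boxes stored in Python lists; the helper names `iou` and `nms` are illustrative, not torchvision's API. Dropping suppressed boxes from the list as we go is equivalent to the `suppressed` flags used in the C++-style code later in this post.

```python
def iou(box_a, box_b):
    """IoU of two boxes in (x1, y1, x2, y2) format."""
    inter_w = max(0.0, min(box_a[2], box_b[2]) - max(box_a[0], box_b[0]))
    inter_h = max(0.0, min(box_a[3], box_b[3]) - max(box_a[1], box_b[1]))
    inter = inter_w * inter_h
    area_a = (box_a[2] - box_a[0]) * (box_a[3] - box_a[1])
    area_b = (box_b[2] - box_b[0]) * (box_b[3] - box_b[1])
    union = area_a + area_b - inter
    return inter / union if union > 0 else 0.0


def nms(boxes, scores, iou_threshold):
    """Greedy NMS; returns indices of kept boxes, highest score first."""
    # Step 1: sort box indices by score, highest first
    order = sorted(range(len(boxes)), key=lambda k: scores[k], reverse=True)
    keep = []
    while order:
        # The highest-scoring remaining box is always kept
        best = order.pop(0)
        keep.append(best)
        # Step 2: suppress every remaining box whose IoU with it exceeds the threshold
        order = [k for k in order if iou(boxes[best], boxes[k]) <= iou_threshold]
    return keep


boxes = [(0, 0, 10, 10), (1, 0, 11, 10), (20, 20, 30, 30)]
scores = [0.9, 0.8, 0.7]
# Box 1 overlaps box 0 with IoU 90/110 (about 0.82 > 0.7), so it is suppressed
print(nms(boxes, scores, 0.7))  # [0, 2]
```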

The result looks like this:

[Figure: result of NMS]

Image source: http://www.telesens.co/2018/03/11/object-detection-and-classification-using-r-cnns/

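As you can see, this approach compares each kept box with the boxes ranked after it, so its time complexity is O(n^2) in the number of boxes, and a loop written in Python would be very slow. That is why torchvision implements NMS in C++ (with a CUDA kernel for GPU tensors). In torchvision, NMS is used like this: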

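from torchvision.ops import boxes as box_ops
# boxes: (N, 4) tensor in (x1, y1, x2, y2) format; scores: (N,) tensor
# lvl: (N,) tensor of group ids; NMS is applied separately within each group
# keep: indices of the kept boxes, sorted by decreasing score
keep = box_ops.batched_nms(boxes, scores, lvl, nms_thresh)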
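
`batched_nms` runs NMS independently for each value of its third argument (here `lvl`), so boxes from different groups never suppress each other. One way to do this, which to my knowledge torchvision's own `batched_nms` has used, is to shift each group's boxes by a large offset so that boxes from different groups can never overlap, then run ordinary NMS once. Below is a plain-Python sketch of that trick; `batched_nms_sketch` is a hypothetical name, and it assumes all coordinates are non-negative and group ids are non-negative integers.

```python
def iou(box_a, box_b):
    """IoU of two boxes in (x1, y1, x2, y2) format."""
    inter_w = max(0.0, min(box_a[2], box_b[2]) - max(box_a[0], box_b[0]))
    inter_h = max(0.0, min(box_a[3], box_b[3]) - max(box_a[1], box_b[1]))
    inter = inter_w * inter_h
    area_a = (box_a[2] - box_a[0]) * (box_a[3] - box_a[1])
    area_b = (box_b[2] - box_b[0]) * (box_b[3] - box_b[1])
    union = area_a + area_b - inter
    return inter / union if union > 0 else 0.0


def nms(boxes, scores, iou_threshold):
    """Greedy NMS over all boxes; returns kept indices, highest score first."""
    order = sorted(range(len(boxes)), key=lambda k: scores[k], reverse=True)
    keep = []
    while order:
        best = order.pop(0)
        keep.append(best)
        order = [k for k in order if iou(boxes[best], boxes[k]) <= iou_threshold]
    return keep


def batched_nms_sketch(boxes, scores, idxs, iou_threshold):
    """NMS applied separately per group id in idxs, via the coordinate-offset trick."""
    if not boxes:
        return []
    # Largest coordinate over all boxes (assumes every coordinate is >= 0)
    max_coordinate = max(max(box) for box in boxes)
    shifted = []
    for box, idx in zip(boxes, idxs):
        # Move group idx into its own region, disjoint from every other group
        offset = idx * (max_coordinate + 1)
        shifted.append(tuple(c + offset for c in box))
    # One ordinary NMS pass now never compares boxes across groups
    return nms(shifted, scores, iou_threshold)


boxes = [(0, 0, 10, 10), (1, 0, 11, 10), (1, 0, 11, 10)]
scores = [0.9, 0.8, 0.7]
groups = [0, 0, 1]
print(nms(boxes, scores, 0.7))                         # [0]
print(batched_nms_sketch(boxes, scores, groups, 0.7))  # [0, 2]
```

Box 2 has the same coordinates as box 1, but because it belongs to a different group it survives the batched version. The offset trick trades a per-group loop for a single NMS call over all boxes.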

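If you want to read the C++ source, it is in the pytorch/vision repository on GitHub:

Path: torchvision/csrc/cpu/nms_cpu.cpp (this was the path when this post was written; later releases reorganized the csrc directory)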

Some readers may not be very familiar with C++, so I reimplemented the C++ source in PyTorch to make it easier to follow.

The code is on GitHub:

https://github.com/VincentZhengg/learn_faster_rcnn

Notebook: nms_implementations.ipynb

import torch
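# Load data saved earlier in this series
# orig_boxes: predicted boxes in (x1, y1, x2, y2) format
# orig_scores: box scores (a higher score means the box is more likely to contain an object rather than background)
# orig_lvl: level index (FPN extracts features at several levels; here they are all 1)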

orig_boxes = torch.load('boxes.pt')
orig_scores = torch.load('scores.pt')
orig_lvl = torch.load('lvl.pt')
nms_thresh = 0.7

# Take 50 boxes (indices 1000 to 1049) to keep the example small
boxes = orig_boxes[1000:1050]
scores = orig_scores[1000:1050]
lvl = orig_lvl[1000:1050]

import cv2
img_path = '2007_000032.jpg'
img = cv2.imread(img_path)
img = cv2.resize(img, (800, 800))
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

for box in boxes:
    startX, startY, endX, endY = int(box[0]), int(box[1]), int(box[2]), int(box[3])
    cv2.rectangle(img, (startX, startY), (endX, endY), (0, 0, 255), 2)

# Show the boxes before NMS
import matplotlib.pyplot as plt
plt.figure(figsize=(20, 30))
plt.imshow(img)

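dets = boxes

# Split out the coordinates, mirroring the C++ source
x1 = dets.select(1, 0).contiguous()
y1 = dets.select(1, 1).contiguous()
x2 = dets.select(1, 2).contiguous()
y2 = dets.select(1, 3).contiguous()

# Area of each box
areas = (x2 - x1) * (y2 - y1)

# Sort the boxes by score, highest first
order = scores.sort(0, descending=True)[1]

ndets = dets.size(0)

# suppressed marks whether a box has been filtered out (1 = suppressed)
# keep stores the indices of the boxes to keep
# Note: lvl is ignored here; this plain NMS matches batched_nms only because all lvl values are equal
suppressed = torch.zeros(ndets, dtype=torch.int64)
keep = torch.zeros(ndets, dtype=torch.int64)

num_to_keep = 0
for _i in range(ndets):
    i = order[_i]
    if suppressed[i] == 1:
        continue
    # i is the highest-scoring box not yet suppressed, so keep it
    keep[num_to_keep] = i
    num_to_keep += 1
    ix1 = x1[i]
    iy1 = y1[i]
    ix2 = x2[i]
    iy2 = y2[i]
    iarea = areas[i]
    # Compare only with lower-scoring boxes (the C++ loop also starts at _i + 1)
    for _j in range(_i + 1, ndets):
        j = order[_j]
        if suppressed[j] == 1:
            continue
        # Intersection rectangle of box i and box j
        xx1 = max(ix1, x1[j])
        yy1 = max(iy1, y1[j])
        xx2 = min(ix2, x2[j])
        yy2 = min(iy2, y2[j])

        # Clamp at 0 so non-overlapping boxes have zero intersection
        w = max(0, xx2 - xx1)
        h = max(0, yy2 - yy1)
        inter = w * h
        ovr = inter / (iarea + areas[j] - inter)
        if ovr > nms_thresh:
            suppressed[j] = 1

# narrow() returns a new view, so the result must be assigned back
keep = keep.narrow(0, 0, num_to_keep)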
# Draw only the boxes kept by NMS on a fresh copy of the image
after_nms_img = cv2.imread(img_path)
after_nms_img = cv2.resize(after_nms_img, (800, 800))
after_nms_img = cv2.cvtColor(after_nms_img, cv2.COLOR_BGR2RGB)

for i in keep:
    box = boxes[int(i)]
    startX, startY, endX, endY = int(box[0]), int(box[1]), int(box[2]), int(box[3])
    cv2.rectangle(after_nms_img, (startX, startY), (endX, endY), (0, 0, 255), 2)

plt.figure(figsize=(20, 30))
plt.imshow(after_nms_img)

Thanks for reading.

铁马: Faster-RCNN Explained and torchvision Source Code Walkthrough (Part 6): RoI Pooling (zhuanlan.zhihu.com)
铁马: Faster-RCNN Explained and torchvision Source Code Walkthrough (Part 7): RoI Align (zhuanlan.zhihu.com)