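Before introducing NMS (Non-Maximum Suppression), we first need the concept of IoU (Intersection over Union). IoU measures how good a predicted box is: it is the area of the intersection of two boxes divided by the area of their union, as shown in the figure below. NMS uses IoU to filter out boxes that overlap heavily with a higher-scoring box.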
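Since the figure is easy to lose, here is a minimal sketch of the IoU computation in plain Python. It assumes boxes are given as `(x1, y1, x2, y2)` with `x1 <= x2` and `y1 <= y2`; the function name `iou` is only illustrative and is not part of torchvision.

```python
def iou(box_a, box_b):
    """IoU of two boxes given as (x1, y1, x2, y2) with x1 <= x2 and y1 <= y2."""
    ax1, ay1, ax2, ay2 = box_a
    bx1, by1, bx2, by2 = box_b
    # Width and height of the overlap rectangle, clamped to zero for disjoint boxes.
    inter_w = max(0.0, min(ax2, bx2) - max(ax1, bx1))
    inter_h = max(0.0, min(ay2, by2) - max(ay1, by1))
    inter = inter_w * inter_h
    # Union = area A + area B - intersection (the overlap is counted only once).
    union = (ax2 - ax1) * (ay2 - ay1) + (bx2 - bx1) * (by2 - by1) - inter
    return inter / union if union > 0 else 0.0


# Two 10x10 boxes shifted by 5 pixels: intersection 50, union 150.
print(iou((0, 0, 10, 10), (5, 0, 15, 10)))
# Disjoint boxes have no intersection.
print(iou((0, 0, 10, 10), (20, 20, 30, 30)))
```

An IoU of 1 means the two boxes coincide, and 0 means they do not overlap at all.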
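The NMS algorithm works as follows:
- Since every box already has an objectness score (how likely it is to contain an object), we sort the boxes by this score from high to low.
- Then, for each box in sorted order, we compute the IoU between it and each of the remaining boxes. Any remaining box whose IoU exceeds a threshold (for example 0.7) is removed (suppressed).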
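The two steps above can be sketched on a toy example in plain Python. This is only an illustration of the greedy procedure under the assumption that boxes are `(x1, y1, x2, y2)` tuples; the helper names `iou` and `greedy_nms` are made up for this sketch.

```python
def iou(a, b):
    """IoU of two (x1, y1, x2, y2) boxes."""
    w = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    h = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = w * h
    union = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - inter
    return inter / union if union > 0 else 0.0


def greedy_nms(boxes, scores, iou_thresh):
    """Return the indices of the kept boxes, highest score first."""
    # Step 1: sort box indices by score, highest first.
    order = sorted(range(len(boxes)), key=lambda k: scores[k], reverse=True)
    suppressed = [False] * len(boxes)
    keep = []
    # Step 2: keep each surviving box and suppress lower-scored boxes that overlap it too much.
    for pos, i in enumerate(order):
        if suppressed[i]:
            continue
        keep.append(i)
        for j in order[pos + 1:]:
            if not suppressed[j] and iou(boxes[i], boxes[j]) > iou_thresh:
                suppressed[j] = True
    return keep


boxes = [(0, 0, 10, 10), (1, 0, 11, 10), (20, 20, 30, 30)]
scores = [0.9, 0.8, 0.7]
# Boxes 0 and 1 have IoU 90/110 (about 0.82), so box 1 is suppressed at threshold 0.7.
print(greedy_nms(boxes, scores, 0.7))
```

With a stricter threshold such as 0.9, the same call keeps all three boxes, because no pair overlaps that much.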
The result is shown in the figure below.
Image source: http://www.telesens.co/2018/03/11/object-detection-and-classification-using-r-cnns/
As you can see, this approach has O(n^2) time complexity, which makes it very slow, so torchvision implements this part in C++. NMS is used in torchvision as follows:
from torchvision.ops import boxes as box_ops
keep = box_ops.batched_nms(boxes, scores, lvl, nms_thresh)
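The third argument (`lvl` here) tells `batched_nms` which group each box belongs to, and NMS is not applied between boxes of different groups. One common way to get this effect with a single NMS pass is to shift every group's boxes into its own region of the plane, so that boxes from different groups can never overlap. The sketch below illustrates that idea in plain Python. It assumes non-negative coordinates and non-negative integer group ids, the helper names `iou` and `shift_by_group` are hypothetical, and it is not a copy of torchvision's code.

```python
def iou(a, b):
    """IoU of two (x1, y1, x2, y2) boxes."""
    w = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    h = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = w * h
    union = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - inter
    return inter / union if union > 0 else 0.0


def shift_by_group(boxes, groups):
    """Offset each box by group_id * (largest coordinate + 1) along both axes."""
    span = max(max(box) for box in boxes) + 1
    return [
        (x1 + g * span, y1 + g * span, x2 + g * span, y2 + g * span)
        for (x1, y1, x2, y2), g in zip(boxes, groups)
    ]


boxes = [(0, 0, 10, 10), (1, 0, 11, 10), (1, 0, 11, 10)]
groups = [0, 0, 1]  # e.g. the FPN level of each box
shifted = shift_by_group(boxes, groups)

# Same group: the overlap is unchanged, so NMS can still suppress box 1.
print(iou(shifted[0], shifted[1]))
# Different groups: the boxes no longer overlap, so box 2 survives.
print(iou(shifted[0], shifted[2]))
```

Running ordinary NMS on the shifted boxes and using the returned indices on the original boxes then behaves like running NMS separately per group.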
If you want to read the C++ source, it is in the pytorch/vision repository on github.com, at the path torchvision/csrc/cpu/nms_cpu.cpp
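Some readers may not be very familiar with C++, so I re-implemented it in PyTorch, following the C++ source, to make it easier to understand.
The code has been uploaded to GitHub (notebook nms_implementations.ipynb):
https://github.com/VincentZhengg/learn_faster_rcnn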
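import torch
# Load the data saved in the earlier steps
# orig_boxes: predicted boxes
# orig_scores: score of each predicted box (the higher the score, the more likely the box contains an object rather than background)
# orig_lvl: feature level of each box (FPN extracts features at several levels; here they are all 1)
orig_boxes = torch.load('boxes.pt')
orig_scores = torch.load('scores.pt')
orig_lvl = torch.load('lvl.pt')
nms_thresh = 0.7
# Take a slice of 50 boxes (indices 1000 to 1049)
boxes = orig_boxes[1000:1050]
scores = orig_scores[1000:1050]
lvl = orig_lvl[1000:1050]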
import cv2
img_path = '2007_000032.jpg'
img = cv2.imread(img_path)
img = cv2.resize(img, (800, 800))
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
for box in boxes:
    startX, startY, endX, endY = int(box[0]), int(box[1]), int(box[2]), int(box[3])
    cv2.rectangle(img, (startX, startY), (endX, endY), (0, 0, 255), 2)
# Show the boxes before NMS is applied
import matplotlib.pyplot as plt
plt.figure(figsize=(20, 30))
plt.imshow(img)
dets = boxes
x1 = dets.select(1, 0).contiguous()
y1 = dets.select(1, 1).contiguous()
x2 = dets.select(1, 2).contiguous()
y2 = dets.select(1, 3).contiguous()
# Area of each box
areas = (x2 - x1) * (y2 - y1)
# Sort the predicted boxes by score, highest first
order = scores.sort(0, descending=True)[1]
ndets = dets.size(0)
# suppressed marks whether a box has been filtered out (1 means suppressed)
# keep stores the indices of the boxes to keep
suppressed = torch.zeros(ndets, dtype=torch.int64)
keep = torch.zeros(ndets, dtype=torch.int64)
num_to_keep = 0
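for _i in range(ndets):
    i = order[_i]
    # Skip boxes already suppressed by a higher-scoring box
    if suppressed[i] == 1:
        continue
    keep[num_to_keep] = i
    num_to_keep += 1
    ix1 = x1[i]
    iy1 = y1[i]
    ix2 = x2[i]
    iy2 = y2[i]
    iarea = areas[i]
    # Compare box i only with the lower-scoring boxes that come after it
    for _j in range(_i + 1, ndets):
        j = order[_j]
        if suppressed[j] == 1:
            continue
        # Corners of the intersection rectangle
        xx1 = max(ix1, x1[j])
        yy1 = max(iy1, y1[j])
        xx2 = min(ix2, x2[j])
        yy2 = min(iy2, y2[j])
        w = max(0, xx2 - xx1)
        h = max(0, yy2 - yy1)
        inter = w * h
        # IoU = intersection / union
        ovr = inter / (iarea + areas[j] - inter)
        if ovr > nms_thresh:
            suppressed[j] = 1
# Only the first num_to_keep entries of keep are valid indices
keep = keep.narrow(0, 0, num_to_keep)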
after_nms_img = cv2.imread(img_path)
after_nms_img = cv2.resize(after_nms_img, (800, 800))
after_nms_img = cv2.cvtColor(after_nms_img, cv2.COLOR_BGR2RGB)
for i in keep:
    box = boxes[int(i)]
    startX, startY, endX, endY = int(box[0]), int(box[1]), int(box[2]), int(box[3])
    cv2.rectangle(after_nms_img, (startX, startY), (endX, endY), (0, 0, 255), 2)
plt.figure(figsize=(20, 30))
plt.imshow(after_nms_img)
Thanks for reading.
Next in this series: 铁马, Faster-RCNN详解和torchvision源码解读(六):roi pooling (zhuanlan.zhihu.com)