功能性模块:(3)NMS :cpu版和pytorch版
一、模块介绍
如果小伙伴们接触过检测方面的算法,应该对NMS不会很陌生,NMS(Non-Maximum Suppression),即非极大值抑制,特别原理性的LZ就不专门介绍了,网上太多了。总的功能是什么呢?就是我们在做检测的时候,假设画面中有一个人脸,但是由于检测算法的不同,可能会在一个人脸上给出多个满足要求的检测框,我们当然不能把这些框都使用上,就挑选一个最满足要求的检测框就可以了。具体的效果如下图所示:
二、代码实现
1. cpu版本实现
def py_cpu_nms(dets, thresh):
x1 = dets[:, 0]
y1 = dets[:, 1]
x2 = dets[:, 2]
y2 = dets[:, 3]
areas = (y2 - y1 + 1) * (x2 - x1 + 1)
scores = dets[:, 4]
keep = []
# 按照score先进行排序
index = scores.argsort()[::-1]
while index.size > 0:
i = index[0] # every time the first is the biggst, and add it directly
keep.append(i)
x11 = np.maximum(x1[i], x1[index[1:]]) # calculate the points of overlap
y11 = np.maximum(y1[i], y1[index[1:]])
x22 = np.minimum(x2[i], x2[index[1:]])
y22 = np.minimum(y2[i], y2[index[1:]])
w = np.maximum(0, x22 - x11 + 1) # the weights of overlap
h = np.maximum(0, y22 - y11 + 1) # the height of overlap
overlaps = w * h
ious = overlaps / (areas[i] + areas[index[1:]] - overlaps)
idx = np.where(ious <= thresh)[0]
index = index[idx + 1] # because index start from 1
return keep
demo
import numpy as np
def py_cpu_nms(dets, thresh):
x1 = dets[:, 0]
y1 = dets[:, 1]
x2 = dets[:, 2]
y2 = dets[:, 3]
areas = (y2 - y1 + 1) * (x2 - x1 + 1)
scores = dets[:, 4]
keep = []
index = scores.argsort()[::-1]
while index.size > 0:
i = index[0] # every time the first is the biggst, and add it directly
keep.append(i)
x11 = np.maximum(x1[i], x1[index[1:]]) # calculate the points of overlap
y11 = np.maximum(y1[i], y1[index[1:]])
x22 = np.minimum(x2[i], x2[index[1:]])
y22 = np.minimum(y2[i], y2[index[1:]])
w = np.maximum(0, x22 - x11 + 1) # the weights of overlap
h = np.maximum(0, y22 - y11 + 1) # the height of overlap
overlaps = w * h
ious = overlaps / (areas[i] + areas[index[1:]] - overlaps)
idx = np.where(ious <= thresh)[0]
index = index[idx + 1] # because index start from 1
return keep
import matplotlib.pyplot as plt
def plot_bbox(dets, c='k'):
x1 = dets[:, 0]
y1 = dets[:, 1]
x2 = dets[:, 2]
y2 = dets[:, 3]
plt.plot([x1, x2], [y1, y1], c)
plt.plot([x1, x1], [y1, y2], c)
plt.plot([x1, x2], [y2, y2], c)
plt.plot([x2, x2], [y1, y2], c)
# plt.title(" nms")
if __name__ == "__main__":
boxes = np.array([[100, 100, 210, 210, 0.72],
[250, 250, 420, 420, 0.8],
[220, 220, 320, 330, 0.92],
[100, 100, 210, 210, 0.72],
[230, 240, 325, 330, 0.81],
[220, 230, 315, 340, 0.9]])
print("boxes shape: ", boxes.shape)
boxes2 = np.random.rand() * boxes
boxes = np.concatenate((boxes, boxes2))
plt.figure(1)
ax1 = plt.subplot(1, 2, 1)
ax2 = plt.subplot(1, 2, 2)
plt.sca(ax1)
plt.title("ori")
plot_bbox(boxes, 'k') # before nms
keep = py_cpu_nms(boxes, thresh=0.5)
plt.sca(ax2)
plt.title("after nms")
plot_bbox(boxes[keep], 'r') # after nms
plt.savefig("./test_nms.jpg")
2. pytorch版本实现
def nms_torch(bboxes, scores, threshold=0.5):
x1 = bboxes[:, 0]
y1 = bboxes[:, 1]
x2 = bboxes[:, 2]
y2 = bboxes[:, 3]
areas = (x2 - x1) * (y2 - y1) # [N,] 每个bbox的面积
_, order = scores.sort(0, descending=True) # 降序排列
keep = []
while order.numel() > 0: # torch.numel()返回张量元素个数
if order.numel() == 1: # 保留框只剩一个
i = order.item()
keep.append(i)
break
else:
i = order[0].item() # 保留scores最大的那个框box[i]
keep.append(i)
# 计算box[i]与其余各框的IOU(思路很好)
xx1 = x1[order[1:]].clamp(min=x1[i]) # [N-1,]
yy1 = y1[order[1:]].clamp(min=y1[i])
xx2 = x2[order[1:]].clamp(max=x2[i])
yy2 = y2[order[1:]].clamp(max=y2[i])
inter = (xx2 - xx1).clamp(min=0) * (yy2 - yy1).clamp(min=0) # [N-1,]
iou = inter / (areas[i] + areas[order[1:]] - inter) # [N-1,]
idx = (iou <= threshold).nonzero().squeeze() # 注意此时idx为[N-1,] 而order为[N,]
if idx.numel() == 0:
break
order = order[idx + 1] # 修补索引之间的差值
return torch.LongTensor(keep) # Pytorch的索引值为LongTensor
demo
import numpy as np
import torch
import matplotlib.pyplot as plt
def py_cpu_nms(dets, thresh):
x1 = dets[:, 0]
y1 = dets[:, 1]
x2 = dets[:, 2]
y2 = dets[:, 3]
areas = (y2 - y1 + 1) * (x2 - x1 + 1)
scores = dets[:, 4]
keep = []
index = scores.argsort()[::-1]
while index.size > 0:
i = index[0] # every time the first is the biggst, and add it directly
keep.append(i)
x11 = np.maximum(x1[i], x1[index[1:]]) # calculate the points of overlap
y11 = np.maximum(y1[i], y1[index[1:]])
x22 = np.minimum(x2[i], x2[index[1:]])
y22 = np.minimum(y2[i], y2[index[1:]])
w = np.maximum(0, x22 - x11 + 1) # the weights of overlap
h = np.maximum(0, y22 - y11 + 1) # the height of overlap
overlaps = w * h
ious = overlaps / (areas[i] + areas[index[1:]] - overlaps)
idx = np.where(ious <= thresh)[0]
index = index[idx + 1] # because index start from 1
return keep
def nms_torch(bboxes, scores, threshold=0.5):
x1 = bboxes[:, 0]
y1 = bboxes[:, 1]
x2 = bboxes[:, 2]
y2 = bboxes[:, 3]
areas = (x2 - x1) * (y2 - y1) # [N,] 每个bbox的面积
_, order = scores.sort(0, descending=True) # 降序排列
keep = []
while order.numel() > 0: # torch.numel()返回张量元素个数
if order.numel() == 1: # 保留框只剩一个
i = order.item()
keep.append(i)
break
else:
i = order[0].item() # 保留scores最大的那个框box[i]
keep.append(i)
# 计算box[i]与其余各框的IOU(思路很好)
xx1 = x1[order[1:]].clamp(min=x1[i]) # [N-1,]
yy1 = y1[order[1:]].clamp(min=y1[i])
xx2 = x2[order[1:]].clamp(max=x2[i])
yy2 = y2[order[1:]].clamp(max=y2[i])
inter = (xx2 - xx1).clamp(min=0) * (yy2 - yy1).clamp(min=0) # [N-1,]
iou = inter / (areas[i] + areas[order[1:]] - inter) # [N-1,]
idx = (iou <= threshold).nonzero().squeeze() # 注意此时idx为[N-1,] 而order为[N,]
if idx.numel() == 0:
break
order = order[idx + 1] # 修补索引之间的差值
return torch.LongTensor(keep) # Pytorch的索引值为LongTensor
def plot_bbox(dets, c='k'):
x1 = dets[:, 0]
y1 = dets[:, 1]
x2 = dets[:, 2]
y2 = dets[:, 3]
plt.plot([x1, x2], [y1, y1], c)
plt.plot([x1, x1], [y1, y2], c)
plt.plot([x1, x2], [y2, y2], c)
plt.plot([x2, x2], [y1, y2], c)
# plt.title(" nms")
if __name__ == "__main__":
boxes = np.array([[100, 100, 210, 210, 0.72],
[250, 250, 420, 420, 0.8],
[220, 220, 320, 330, 0.92],
[100, 100, 210, 210, 0.72],
[230, 240, 325, 330, 0.81],
[220, 230, 315, 340, 0.9]])
print("boxes shape: ", boxes.shape)
boxes2 = np.random.rand() * boxes
boxes = np.concatenate((boxes, boxes2))
boxes_t = torch.tensor(boxes[:, :4]).cuda()
scores_t = torch.tensor(boxes[:, -1]).cuda()
keep = nms_torch(boxes_t, scores_t)
boxes_res = boxes_t[keep]
boxes_res = boxes_res.cpu().numpy()
plt.figure(1)
ax1 = plt.subplot(1, 2, 1)
ax2 = plt.subplot(1, 2, 2)
plt.sca(ax1)
plt.title("ori")
plot_bbox(boxes, 'k') # before nms
# keep = py_cpu_nms(boxes, thresh=0.5)
plt.sca(ax2)
plt.title("after nms")
plot_bbox(boxes_res, 'r') # after nms
plt.savefig("./test_nms_torch.jpg")
其实在使用centernet的时候基本上可以不使用nms了,但是对于某些特殊情况,还是需要增加nms的。对于检测来说,NMS还是非常重要的一个部分!
参考地址
https://zhuanlan.zhihu.com/p/54709759