参考别人的代码,发现别人的代码比较繁琐,以下是精简后的NMS代码,代码可读性强,最后给出可视化图以及可视化代码(可视化是参考别人的代码)
NMS过程:
1 将各组box按照score降序排列
2 从score最大值开始,置为当前box,保存idex,然后依次遍历后面的box,计算与当前box的IOU值,若大于阈值,则抑制,不会输出
3 完成一轮遍历后,继续选择下一个非抑制的box作为当前box,重复步骤2
4 返回没有被抑制的index即符合条件的box
NMS在过程中需要计算iou,所以直接给出iou的函数:
def iou(bbox,gt):
#lt是两个框中间重叠框的最左边和最上边的坐标,rb是两个框中间重叠框的最右边和最下边的坐标
lt = np.maximum(bbox[:,None,:2],gt[:,:2]) # [N,M,2]
rb = np.minimum(bbox[:,None,2:4],gt[:,2:4]) # [N,M,2]
#wh是重叠框的宽和高,+1是因为求边长,边长就等于前后两端的坐标点相减且+1
wh = np.maximum(rb - lt + 1 , 0) # [N,M,2]
#求重叠框的面积
inter_area = wh[:,:,0] * wh[:,:,1] #[N,M]
#分别求两个框各自的面积
bbox_area = (bbox[:,2] - bbox[:,0] + 1) * (bbox[:,3] - bbox[:,1] + 1) #[N,]
gt_area = (gt[:,2] - gt[:,0] + 1) * (gt[:,3] - gt[:,1] + 1) #[M,]
#iou的公式,重叠框面积 / 两个框面积之和减去重叠框面积
iou = inter_area / (bbox_area[:,None] + gt_area - inter_area) #[N,M]
return iou
下面是NMS的代码:
def nms(bbox,thresh):
#得分bbox第五列是得分,前四列是x0,y0,x1,y1
score = bbox[:,4]
#对得分进行排序
order = np.argsort(score)
#记录结果值,每次保存得分最高的那个框的索引,最后再用bbox[keep]取出相应框
keep = []
#一直筛选到没有可用的框
while order.size > 0:
#取得分最高的框的索引,因为order是升序,所以最后一位是得分最高的
index = order[-1]
#保存得分最高的那个框的索引
keep.append(index)
#取出这个框
x = bbox[index]
#计算iou,x[None,:]是为了保持shape一致,squeeze(0)是去掉第一个维度,不去掉的话结果的shape是[1,5],再np.where就不对了,必须让其等于[5,]
sub_bbox_iou = iou(x[None,:],bbox[order[:-1]]).squeeze(0)
#筛选小于阈值的框,大于阈值的话,就和得分最高的那个框重叠了,所以保留不重叠的框
index_after = np.where(sub_bbox_iou < thresh)
#筛选剩下的框
order = order[index_after]
return keep
下面给出可视化的代码以及效果图:
import matplotlib.pyplot as plt
import numpy as np
#先初始化boxes
boxes=np.array([[100,100,210,210,0.72],
[250,250,420,420,0.8],
[220,220,320,330,0.92],
[100,100,210,210,0.72],
[230,240,325,330,0.81],
[220,230,315,340,0.9]])
def plot_bbox(dets, c='k'):
x1 = dets[:,0]
y1 = dets[:,1]
x2 = dets[:,2]
y2 = dets[:,3]
plt.plot([x1,x2], [y1,y1], c)
plt.plot([x1,x1], [y1,y2], c)
plt.plot([x1,x2], [y2,y2], c)
plt.plot([x2,x2], [y1,y2], c)
plt.title(" nms")
#在jupyter运行,或者在py运行
plt.figure(1)
ax1 = plt.subplot(1,2,1)
ax2 = plt.subplot(1,2,2)
plt.sca(ax1)
plot_bbox(boxes,'k') # before nms
keep = nms(boxes, 0.7)
plt.sca(ax2)
plot_bbox(boxes[keep], 'r')# after nms
左边的是没有进行nms的效果,右边的是进行nms后的效果。