理论参考:https://blog.csdn.net/a1103688841/article/details/89711120
源码:
import numpy as np
def py_cpu_nms(dets, thresh, *, verbose=True):
    """Greedy non-maximum suppression (NMS) implemented with NumPy.

    Repeatedly keeps the highest-scoring remaining box, then discards every
    other remaining box whose IoU with it is greater than ``thresh``.

    Args:
        dets: Array-like of shape (N, 5) with rows ``[x1, y1, x2, y2, score]``,
            where (x1, y1) is the top-left and (x2, y2) the bottom-right
            corner. Coordinates are treated as inclusive pixel indices, so a
            box's width is ``x2 - x1 + 1`` and its height ``y2 - y1 + 1``.
        thresh: IoU threshold. A remaining box survives a round only if its
            IoU with the box just kept is ``<= thresh``.
        verbose: Keyword-only. When True (the default), print the
            intermediate arrays of every round, as in the tutorial output.

    Returns:
        List of Python ``int`` row indices into ``dets`` for the kept boxes,
        in descending score order.
    """
    dets = np.asarray(dets)

    def log(message):
        # Debug output is optional so the function can be used quietly.
        if verbose:
            print(message)

    # Per-box coordinates and scores; each has shape (N,).
    x1 = dets[:, 0]
    y1 = dets[:, 1]
    x2 = dets[:, 2]
    y2 = dets[:, 3]
    scores = dets[:, 4]
    # The +1 follows the inclusive-pixel convention described above.
    areas = (y2 - y1 + 1) * (x2 - x1 + 1)
    log("areas.shape: {}".format(areas.shape))
    log("areas: {}".format(areas))

    keep = []
    # Box indices ordered by score, highest first; shape (N,).
    index = scores.argsort()[::-1]
    log("index.shape: {}".format(index.shape))
    log("index: {}".format(index))
    while index.size > 0:
        # The highest-scoring box still in play is always kept.
        i = index[0]
        keep.append(int(i))
        rest = index[1:]

        # Intersection rectangle of box i with each remaining box: the
        # element-wise max of the top-left corners and min of the
        # bottom-right corners. Each array has shape (rest.size,).
        x11 = np.maximum(x1[i], x1[rest])
        y11 = np.maximum(y1[i], y1[rest])
        x22 = np.minimum(x2[i], x2[rest])
        y22 = np.minimum(y2[i], y2[rest])
        log("index[1:]: {}".format(rest))
        log("x1[index[1:]]: {}".format(x1[rest]))
        log("x11: {}".format(x11))
        log("x11.shape: {}".format(x11.shape))

        # For overlapping boxes the extents are positive. For disjoint boxes
        # x22 - x11 + 1 (or the y equivalent) is <= 0, so clamp it to 0.
        w = np.maximum(0, x22 - x11 + 1)
        h = np.maximum(0, y22 - y11 + 1)
        log("w: {}".format(w))
        log("w.shape: {}".format(w.shape))

        # Intersection area.
        overlaps = w * h
        log("overlaps: {}".format(overlaps))
        log("overlaps.shape: {}".format(overlaps.shape))

        # IoU = intersection / union.
        ious = overlaps / (areas[i] + areas[rest] - overlaps)
        log("ious.shape: {}".format(ious.shape))
        log("ious: {}".format(ious))

        # Positions within `rest` (not within `index`) of boxes that overlap
        # box i little enough to survive this round.
        ious_idx = np.where(ious <= thresh)[0]
        log("ious<=thres idx: {}".format(ious_idx))
        # Equivalent to index[ious_idx + 1], since `rest` is `index` minus its head.
        index = rest[ious_idx]
        log("index: {}".format(index))
    return keep
if __name__ == "__main__":
    # Demo input: one [x1, y1, x2, y2, score] row per box (row index in the
    # trailing comment). Rows 0 and 3 are identical boxes with equal scores.
    boxes = np.array([[100, 100, 210, 210, 0.72],  # 0
                      [250, 250, 420, 420, 0.8],   # 1
                      [220, 220, 320, 330, 0.92],  # 2
                      [100, 100, 210, 210, 0.72],  # 3
                      [230, 240, 325, 330, 0.81],  # 4
                      [220, 230, 315, 340, 0.9]])  # 5
    keep = py_cpu_nms(boxes, thresh=0.7)
    print("keep: {}".format(keep))
输出:
areas.shape: (6,)
areas: [ 12321. 29241. 11211. 12321. 8736. 10656.]
index.shape: (6,)
index: [2 5 4 1 3 0]
index[1:]: [5 4 1 3 0]
x1[index[1:]]: [ 220. 230. 250. 100. 100.]
x11: [ 220. 230. 250. 220. 220.]
x11.shape: (5,)
w: [ 96. 91. 71. 0. 0.]
w.shape: (5,)
overlaps: [ 9696. 8281. 5751. 0. 0.]
overlaps.shape: (5,)
ious.shape: (5,)
ious: [ 0.79664777 0.70984056 0.16573009 0. 0. ]
ious<=thres idx: [2 3 4]
index: [1 3 0]
index[1:]: [3 0]
x1[index[1:]]: [ 100. 100.]
x11: [ 250. 250.]
x11.shape: (2,)
w: [ 0. 0.]
w.shape: (2,)
overlaps: [ 0. 0.]
overlaps.shape: (2,)
ious.shape: (2,)
ious: [ 0. 0.]
ious<=thres idx: [0 1]
index: [3 0]
index[1:]: [0]
x1[index[1:]]: [ 100.]
x11: [ 100.]
x11.shape: (1,)
w: [ 111.]
w.shape: (1,)
overlaps: [ 12321.]
overlaps.shape: (1,)
ious.shape: (1,)
ious: [ 1.]
ious<=thres idx: []
index: []
keep: [2, 1, 3]