目录
第一种:普通NMS,需要安装numpy库
import numpy as np
from tqdm import tqdm
# 普通处理
def nms_1(boxes, threshold=0.5):
result = []
while len(boxes) > 0:
index = [i[4] for i in boxes] # 把所有的分数取出来
max_ = np.argmax(index) # 取出index中元素最大值所对应的索引,此时最大值0.92,其对应的位置索引值为2,(索引值默认从0开始)
max_cor = boxes[max_]
result.append(boxes[max_])
boxes = np.delete(boxes, max_, axis=0) # 把原来boxes中的max_索引对应的数据删除
res = []
for i in tqdm(range(len(boxes))):
if iou(max_cor[:-1], boxes[i][:-1]) < threshold:
res.append(boxes[i])
boxes = res
return result
# 软处理
def nms_2(boxes, threshold, λ): # soft_nms_2
score = boxes[:][4]
boxes = boxes[:][:4]
result = []
while len(boxes) > 0:
boxes = [i for i in boxes if i[-1] > score]
if len(boxes) == 0:
break
index = [i[-1] for i in boxes]
max_ = np.argmax(index)
max_cor = boxes[max_]
boxes = np.delete(boxes, max_, axis=0)
result.append(max_cor)