import numpy as np
def nms_multi_cls(boxes, bbox_threshold,overlap_threshold):
'''
:param boxes: [bbox_score,xmin,ymin,xmax,ymax,c1_score,c2_score,...,cn_score]
:return:
'''
# 如果没有检测到任何框,返回一个空列表
if len(boxes) == 0:
return []
boxes
## 筛选出置信度较高的框
boxes = boxes[boxes[:,0] > bbox_threshold]
# 初始化保留框列表和分数列表
result = []
## 截选出类别概率
cls_scores = boxes[:, 5:] ##
max_cls_index = np.argmax(cls_scores, axis=-1) ## 最大的类别分数的索引
max_cls_score = np.max(cls_scores, axis=-1) ## 最大类别分数
## detection [bbox_score,xmin,ymin,xmax,ymax,max_cls_score,max_cls]
detections = np.concatenate([boxes[:, :5], max_cls_score[:, np.newaxis], max_cls_index[:, np.newaxis]], axis=-1)
detections[:, 0] = detections[:, 0] * detections[:, 5]
detections = detections[detections[:, 0] > bbox_threshold]
uniq_cls = np.unique(max_cls_index) ## 检测结果存在的类别
for c in uniq_cls: ## 遍历每个类别,进行nms操作
det = detections[detections[:, -1] == c]
dets = nms_pi(det, thresh=overlap_threshold)
# Add max detections to outputs
if len(dets):
result.append(dets)
if len(result):
result = np.concatenate(result, axis=0)
else:
return []
return result
def nms_pi(dets, thresh=0.25):
"""
refer to:
https://github.com/facebookresearch/Detectron/blob/main/detectron/utils/cython_nms.pyx
Apply classic DPM-style greedy NMS.
"""
if dets.shape[0] == 0:
return dets[[], :]
scores = dets[:, 0]
x1 = dets[:, 1]
y1 = dets[:, 2]
x2 = dets[:, 3]
y2 = dets[:, 4]
areas = (x2 - x1 + 1) * (y2 - y1 + 1) ## 计算左右框的面积
order = scores.argsort()[::-1] ## 按类别置信度由大到小排序
ndets = dets.shape[0]
suppressed = np.zeros((ndets), dtype=np.int) ## 用来记录检测框是否被抑制了
for _i in range(ndets):
i = order[_i] ## order,按检测框置信度由大到小取出检测它的索引
if suppressed[i] == 1:
continue
ix1 = x1[i]
iy1 = y1[i]
ix2 = x2[i]
iy2 = y2[i]
iarea = areas[i] ## 取出面积
for _j in range(_i + 1, ndets): ## 依次遍历剩下的框, 计算与当前最大的置信度的框之间的iou
j = order[_j]
if suppressed[j] == 1:
continue
xx1 = max(ix1, x1[j])
yy1 = max(iy1, y1[j])
xx2 = min(ix2, x2[j])
yy2 = min(iy2, y2[j])
w = max(0.0, xx2 - xx1 + 1)
h = max(0.0, yy2 - yy1 + 1)
inter = w * h
ovr = inter / (iarea + areas[j] - inter)
if ovr >= thresh:
suppressed[j] = 1
keep = np.where(suppressed == 0)[0]
dets = dets[keep, :]
### score,xmin,ymin,xmax,ymax,?
return dets
if __name__ == '__main__':
boxes = np.zeros((17, 10))
## [bbox_score,xmin,ymin,xmax,ymax,c1_score,c2_score,...,cn_score]
bbox_threshold = 0.25
overlap_threshold = 0.25
nms_multi_cls(boxes, bbox_threshold,overlap_threshold)
nms是抑制掉同类别间重合度较高的框