The NMS algorithm proceeds as follows:
# INPUT: all predicted bounding boxes (bbx) with their coordinates and confidence scores, plus an IoU threshold (bbx whose IoU with a kept box exceeds this threshold are removed)
for each object class:
    (1) Collect all bbx belonging to the current class.
    (2) Sort these bbx by confidence in descending order and record the bbx with the highest confidence.
    (3) Compute the IoU between that highest-confidence bbx and every remaining bbx, and remove all bbx whose IoU exceeds the threshold.
    (4) Repeat steps (2) and (3) on the remaining bbx until no further bbx can be removed.
Note that NMS is performed separately for each class. For example, suppose the final predictions contain two classes, cup and pen. Before NMS, each class may have more than one predicted bbx, so we need to run the NMS procedure once for each of the two classes. A small worked IoU example for step (3) is given right below.
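Step (3) hinges on the IoU (intersection over union) of two boxes. As a quick illustration (the boxes below are made up, and the continuous-coordinate convention is used here, without the +1 pixel correction that appears in the implementation further down), two boxes in [x1, y1, x2, y2] format can be compared like this:
def iou(box_a, box_b):
    """IoU of two axis-aligned boxes in [x1, y1, x2, y2] format."""
    xx1, yy1 = max(box_a[0], box_b[0]), max(box_a[1], box_b[1])  # top-left corner of the intersection
    xx2, yy2 = min(box_a[2], box_b[2]), min(box_a[3], box_b[3])  # bottom-right corner of the intersection
    inter = max(0.0, xx2 - xx1) * max(0.0, yy2 - yy1)            # intersection area (0 if the boxes do not overlap)
    area_a = (box_a[2] - box_a[0]) * (box_a[3] - box_a[1])
    area_b = (box_b[2] - box_b[0]) * (box_b[3] - box_b[1])
    return inter / (area_a + area_b - inter)

print(iou([0, 0, 100, 100], [50, 50, 150, 150]))  # ~0.14: below a 0.2 threshold, the lower-scoring box survives
print(iou([0, 0, 100, 100], [10, 10, 110, 110]))  # ~0.68: above the threshold, the lower-scoring box is removed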
We now implement NMS in Python. Assume that, for a single image, all the bbx information has been stored in a dictionary of the following form:
predicts_dict: {"cup": [[x1_1, y1_1, x2_1, y2_1, scores1], [x1_2, y1_2, x2_2, y2_2, scores2], ...], "pen": [[x1_1, y1_1, x2_1, y2_1, scores1], [x1_2, y1_2, x2_2, y2_2, scores2], ...]}
That is, the box coordinates and confidence of each target are stored in lists, and each sub-list represents one bbx. The full code is as follows:
import numpy as np

def non_max_suppress(predicts_dict, threshold=0.2):
    """
    Apply non-maximum suppression to the predicted bounding boxes.
    Args:
        predicts_dict: {"cup": [[x1, y1, x2, y2, score], ...], "pen": [...]}.
        threshold: IoU threshold; boxes whose IoU with a kept box exceeds it are removed.
    Returns:
        predicts_dict processed by non-maximum suppression.
    """
    for object_name, bbox in predicts_dict.items():  # run NMS separately for each class
        bbox_array = np.array(bbox, dtype=float)
        # Get the coordinates and confidence of all bounding boxes (bbx) of the
        # current class, and compute the area of every bbx.
        x1, y1, x2, y2, scores = (bbox_array[:, 0], bbox_array[:, 1],
                                  bbox_array[:, 2], bbox_array[:, 3], bbox_array[:, 4])
        areas = (x2 - x1 + 1) * (y2 - y1 + 1)
        # Sort the bbx of the current class by confidence from high to low
        # (order holds the sorted indices).
        order = scores.argsort()[::-1]
        keep = []  # indices of the bbx that are finally kept
        # Walk through the bbx in descending confidence order and remove every bbx
        # whose IoU with the currently best bbx exceeds the threshold.
        while order.size > 0:
            i = order[0]
            keep.append(i)  # keep the bbx with the highest remaining confidence
            # Top-left and bottom-right corners of the intersection between the current
            # bbx and all remaining bbx, then their IoU (computed in one vectorized step).
            # When order.size == 1 these are empty arrays, which does not affect the result.
            xx1 = np.maximum(x1[i], x1[order[1:]])
            yy1 = np.maximum(y1[i], y1[order[1:]])
            xx2 = np.minimum(x2[i], x2[order[1:]])
            yy2 = np.minimum(y2[i], y2[order[1:]])
            inter = np.maximum(0.0, xx2 - xx1 + 1) * np.maximum(0.0, yy2 - yy1 + 1)
            iou = inter / (areas[i] + areas[order[1:]] - inter)
            # Indices of the surviving bbx; +1 because the IoU vector does not include
            # the current bbx itself.
            indexes = np.where(iou <= threshold)[0] + 1
            order = order[indexes]  # keep only the surviving indices
        predicts_dict[object_name] = bbox_array[keep].tolist()
    return predicts_dict
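To see the function in action, here is a small usage sketch; the boxes and scores below are made up purely for illustration:
predicts_dict = {
    "cup": [[100, 100, 200, 200, 0.95],   # highest "cup" confidence, kept
            [105, 110, 210, 210, 0.80],   # heavy overlap with the box above (IoU ~0.72), removed
            [300, 300, 380, 400, 0.70]],  # no overlap with the kept box, kept
    "pen": [[50, 50, 120, 90, 0.60]],     # only "pen" box, kept
}
print(non_max_suppress(predicts_dict, threshold=0.2))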
The YOLO source code shows this in general.py (the function below relies on torch, torchvision, time, and on helpers such as xywh2xyxy, box_iou, and LOGGER that are imported or defined elsewhere in that file):
def non_max_suppression(
        prediction,
        conf_thres=0.25,
        iou_thres=0.45,
        classes=None,
        agnostic=False,
        multi_label=False,
        labels=(),
        max_det=300,
        nm=0,  # number of masks
):
    """Non-Maximum Suppression (NMS) on inference results to reject overlapping detections
    Returns:
        list of detections, one (n,6) tensor per image [xyxy, conf, cls]
    """
    # Checks
    assert 0 <= conf_thres <= 1, f'Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0'
    assert 0 <= iou_thres <= 1, f'Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0'
    if isinstance(prediction, (list, tuple)):  # YOLOv5 model in validation mode, output = (inference_out, loss_out)
        prediction = prediction[0]  # select only inference output

    device = prediction.device
    mps = 'mps' in device.type  # Apple MPS
    if mps:  # MPS not fully supported yet, convert tensors to CPU before NMS
        prediction = prediction.cpu()
    bs = prediction.shape[0]  # batch size
    nc = prediction.shape[2] - nm - 5  # number of classes
    xc = prediction[..., 4] > conf_thres  # candidates

    # Settings
    # min_wh = 2  # (pixels) minimum box width and height
    max_wh = 7680  # (pixels) maximum box width and height
    max_nms = 30000  # maximum number of boxes into torchvision.ops.nms()
    time_limit = 0.5 + 0.05 * bs  # seconds to quit after
    redundant = True  # require redundant detections
    multi_label &= nc > 1  # multiple labels per box (adds 0.5ms/img)
    merge = False  # use merge-NMS

    t = time.time()
    mi = 5 + nc  # mask start index
    output = [torch.zeros((0, 6 + nm), device=prediction.device)] * bs
    for xi, x in enumerate(prediction):  # image index, image inference
        # Apply constraints
        # x[((x[..., 2:4] < min_wh) | (x[..., 2:4] > max_wh)).any(1), 4] = 0  # width-height
        x = x[xc[xi]]  # confidence

        # Cat apriori labels if autolabelling
        if labels and len(labels[xi]):
            lb = labels[xi]
            v = torch.zeros((len(lb), nc + nm + 5), device=x.device)
            v[:, :4] = lb[:, 1:5]  # box
            v[:, 4] = 1.0  # conf
            v[range(len(lb)), lb[:, 0].long() + 5] = 1.0  # cls
            x = torch.cat((x, v), 0)

        # If none remain process next image
        if not x.shape[0]:
            continue

        # Compute conf
        x[:, 5:] *= x[:, 4:5]  # conf = obj_conf * cls_conf

        # Box/Mask
        box = xywh2xyxy(x[:, :4])  # (center_x, center_y, width, height) to (x1, y1, x2, y2)
        mask = x[:, mi:]  # zero columns if no masks

        # Detections matrix nx6 (xyxy, conf, cls)
        if multi_label:
            i, j = (x[:, 5:mi] > conf_thres).nonzero(as_tuple=False).T
            x = torch.cat((box[i], x[i, 5 + j, None], j[:, None].float(), mask[i]), 1)
        else:  # best class only
            conf, j = x[:, 5:mi].max(1, keepdim=True)
            x = torch.cat((box, conf, j.float(), mask), 1)[conf.view(-1) > conf_thres]

        # Filter by class
        if classes is not None:
            x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]

        # Apply finite constraint
        # if not torch.isfinite(x).all():
        #     x = x[torch.isfinite(x).all(1)]

        # Check shape
        n = x.shape[0]  # number of boxes
        if not n:  # no boxes
            continue
        x = x[x[:, 4].argsort(descending=True)[:max_nms]]  # sort by confidence and remove excess boxes

        # Batched NMS
        c = x[:, 5:6] * (0 if agnostic else max_wh)  # classes
        boxes, scores = x[:, :4] + c, x[:, 4]  # boxes (offset by class), scores
        i = torchvision.ops.nms(boxes, scores, iou_thres)  # NMS
        i = i[:max_det]  # limit detections
        if merge and (1 < n < 3E3):  # Merge NMS (boxes merged using weighted mean)
            # update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
            iou = box_iou(boxes[i], boxes) > iou_thres  # iou matrix
            weights = iou * scores[None]  # box weights
            x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True)  # merged boxes
            if redundant:
                i = i[iou.sum(1) > 1]  # require redundancy

        output[xi] = x[i]
        if mps:
            output[xi] = output[xi].to(device)
        if (time.time() - t) > time_limit:
            LOGGER.warning(f'WARNING: NMS time limit {time_limit:.3f}s exceeded')
            break  # time limit exceeded

    return output
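For context, during inference this function is applied to the raw model output; a minimal usage sketch (where model, im and the threshold values are placeholders, not part of the code above) looks like this:
pred = model(im)  # raw predictions, roughly shape (batch, num_boxes, 5 + num_classes)
pred = non_max_suppression(pred, conf_thres=0.25, iou_thres=0.45, classes=None, agnostic=False, max_det=300)
for det in pred:  # one (n, 6) tensor per image: x1, y1, x2, y2, confidence, class
    ...  # rescale the boxes to the original image and draw or save them
Also note a design difference from the per-class Python loop above: each box is offset by its class index times max_wh (the c term), so boxes of different classes can never overlap, and a single torchvision.ops.nms call therefore performs class-wise NMS (unless agnostic=True).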