源码
def non_max_suppression(
    prediction,
    conf_thres=0.25,
    iou_thres=0.45,
    classes=None,
    agnostic=False,
    multi_label=False,
    labels=(),
    max_det=300,
    nc=0,  # number of classes (optional)
    max_time_img=0.05,
    max_nms=30000,
    max_wh=7680,
    in_place=True,
):
    """
    Perform non-maximum suppression (NMS) on a set of boxes, with support for masks and multiple labels per box.

    Args:
        prediction (torch.Tensor): A tensor of shape (batch_size, 4 + num_classes + num_masks, num_boxes)
            containing, in this channel order, the predicted boxes (xywh), class scores and mask channels,
            as output by a model such as YOLO. A list/tuple is also accepted; only its first element is used.
        conf_thres (float): Confidence threshold; candidates whose class score is not strictly greater
            than this value are discarded. Valid values are between 0.0 and 1.0.
        iou_thres (float): IoU threshold used by NMS; a box whose IoU with a higher-scoring kept box is
            greater than this value is suppressed. Valid values are between 0.0 and 1.0.
        classes (List[int], optional): Class indices to keep. If None, all classes are kept.
        agnostic (bool): If True, run class-agnostic NMS, i.e. boxes of different classes can suppress
            each other. If False, suppression only happens between boxes of the same class.
        multi_label (bool): If True, a box yields one detection for every class whose score exceeds
            ``conf_thres``. Forced to False when there is only one class.
        labels (Sequence[torch.Tensor]): Optional per-image a-priori labels (autolabelling).
            ``labels[i]`` is a tensor of shape (n, 5) with rows (class_index, x, y, w, h); these boxes
            are appended to image i's candidates with a score of 1.0 for their class.
        max_det (int): Maximum number of detections kept per image after NMS.
        nc (int, optional): Number of classes output by the model. Channels after ``4 + nc`` are treated
            as masks. If 0, it is inferred as ``prediction.shape[1] - 4`` (i.e. no mask channels).
        max_time_img (float): Time budget (seconds) per image; the overall limit is
            ``0.5 + max_time_img * batch_size``.
        max_nms (int): Maximum number of boxes passed to torchvision.ops.nms() per image; when exceeded,
            only the ``max_nms`` highest-confidence boxes are kept.
        max_wh (int): Per-class coordinate offset used to separate classes in a single batched NMS call;
            it must be larger than any box coordinate (pixels) for the separation to hold.
        in_place (bool): If True (default, original behavior) the xywh -> xyxy conversion is written into
            ``prediction``'s storage; because ``transpose`` returns a view, this modifies the caller's
            tensor. If False, a new tensor is built and the input is left untouched.

    Returns:
        (List[torch.Tensor]): A list of length batch_size, where each element is a tensor of
            shape (num_kept, 6 + num_masks) containing the kept boxes, with columns
            (x1, y1, x2, y2, confidence, class, mask1, mask2, ...). If the time limit is exceeded,
            images not yet processed keep an empty (0, 6 + num_masks) tensor.
    """
    # Checks
    assert 0 <= conf_thres <= 1, f'Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0'
    assert 0 <= iou_thres <= 1, f'Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0'
    if isinstance(prediction, (list, tuple)):  # YOLOv8 model in validation mode, output = (inference_out, loss_out)
        prediction = prediction[0]  # select only inference output

    bs = prediction.shape[0]  # batch size
    nc = nc or (prediction.shape[1] - 4)  # number of classes
    nm = prediction.shape[1] - nc - 4  # number of mask channels (0 when nc is inferred)
    mi = 4 + nc  # mask start index
    # candidates: best class score above threshold -> bool tensor of shape (bs, num_boxes)
    xc = prediction[:, 4:mi].amax(1) > conf_thres

    # Settings
    # min_wh = 2  # (pixels) minimum box width and height
    time_limit = 0.5 + max_time_img * bs  # seconds to quit after
    multi_label &= nc > 1  # multiple labels per box (adds 0.5ms/img)

    prediction = prediction.transpose(-1, -2)  # shape(1,84,6300) to shape(1,6300,84)
    if in_place:
        # transpose() returns a view, so this write also changes the caller's tensor
        prediction[..., :4] = xywh2xyxy(prediction[..., :4])  # xywh to xyxy
    else:
        # build a new tensor so the caller's prediction is not modified
        prediction = torch.cat((xywh2xyxy(prediction[..., :4]), prediction[..., 4:]), dim=-1)  # xywh to xyxy

    t = time.time()
    output = [torch.zeros((0, 6 + nm), device=prediction.device)] * bs
    for xi, x in enumerate(prediction):  # image index, image inference
        # Apply constraints
        # x[((x[:, 2:4] < min_wh) | (x[:, 2:4] > max_wh)).any(1), 4] = 0  # width-height
        x = x[xc[xi]]  # confidence: keep only this image's candidate boxes

        # Cat apriori labels if autolabelling: each label becomes a box with score 1.0 for its class
        if labels and len(labels[xi]):
            lb = labels[xi]
            v = torch.zeros((len(lb), nc + nm + 4), device=x.device)
            v[:, :4] = xywh2xyxy(lb[:, 1:5])  # box
            v[range(len(lb)), lb[:, 0].long() + 4] = 1.0  # cls
            x = torch.cat((x, v), 0)

        # If none remain process next image
        if not x.shape[0]:
            continue

        # Detections matrix nx6 (xyxy, conf, cls)
        box, cls, mask = x.split((4, nc, nm), 1)

        if multi_label:
            # i: candidate (row) index, j: class index -- one pair per class score above threshold
            i, j = torch.where(cls > conf_thres)
            x = torch.cat((box[i], x[i, 4 + j, None], j[:, None].float(), mask[i]), 1)
        else:  # best class only
            conf, j = cls.max(1, keepdim=True)
            x = torch.cat((box, conf, j.float(), mask), 1)[conf.view(-1) > conf_thres]

        # Filter by class
        if classes is not None:
            x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]

        # Check shape
        n = x.shape[0]  # number of boxes
        if not n:  # no boxes
            continue
        if n > max_nms:  # excess boxes
            x = x[x[:, 4].argsort(descending=True)[:max_nms]]  # sort by confidence and remove excess boxes

        # Batched NMS: shifting every box by class * max_wh moves different classes apart so they never
        # overlap, making one nms() call equivalent to per-class NMS (the shift is 0 when agnostic).
        c = x[:, 5:6] * (0 if agnostic else max_wh)  # classes
        boxes, scores = x[:, :4] + c, x[:, 4]  # boxes (offset by class), scores
        i = torchvision.ops.nms(boxes, scores, iou_thres)  # NMS
        i = i[:max_det]  # limit detections

        # # Experimental (disabled)
        # merge = False  # use merge-NMS
        # if merge and (1 < n < 3E3):  # Merge NMS (boxes merged using weighted mean)
        #     # Update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
        #     from .metrics import box_iou
        #     iou = box_iou(boxes[i], boxes) > iou_thres  # iou matrix
        #     weights = iou * scores[None]  # box weights
        #     x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True)  # merged boxes
        #     redundant = True  # require redundant detections
        #     if redundant:
        #         i = i[iou.sum(1) > 1]  # require redundancy

        # x holds the un-offset boxes; the class offset only affected the NMS input
        output[xi] = x[i]
        if (time.time() - t) > time_limit:
            LOGGER.warning(f'WARNING ⚠️ NMS time limit {time_limit:.3f}s exceeded')
            break  # time limit exceeded

    return output
解释
1、参数检查
# Checks
assert 0 <= conf_thres <= 1, f'Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0'
assert 0 <= iou_thres <= 1, f'Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0'
if isinstance(prediction, (list, tuple)): # YOLOv8 model in validation mode, output = (inference_out, loss_out)
prediction = prediction[0] # select only inference output
2、初始化变量
bs = prediction.shape[0] # batch size
nc = nc or (prediction.shape[1] - 4) # number of classes
nm = prediction.shape[1] - nc - 4
mi = 4 + nc # mask start index
xc = prediction[:, 4:mi].amax(1) > conf_thres # candidates
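nm 是 number of masks,即掩码(mask)通道的数量,只在实例分割模型中用到;nc 由通道数推断时(普通检测模型),nm 为 0。这里主要看一下 xc:它对每个框取所有类别分数中的最大值,判断是否高于 conf_thres,得到形状为 (batch_size, num_boxes) 的布尔掩码,用来筛选出置信度高于阈值的候选框。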
3、设置时间限制和多标签模式
- 定义时间限制,超过该时间将停止处理。
- 如果类别数量大于 1,才允许多标签模式。
# Settings
# min_wh = 2 # (pixels) minimum box width and height
time_limit = 0.5 + max_time_img * bs # seconds to quit after
multi_label &= nc > 1 # multiple labels per box (adds 0.5ms/img)
4、数据准备
- 转置预测张量,使其形状变为 (batch_size, num_boxes, 4 + num_classes + num_masks)。
- 将边界框坐标从中心点和宽高 (xywh) 转换为左上角和右下角坐标 (xyxy)。
prediction = prediction.transpose(-1, -2) # shape(1,84,6300) to shape(1,6300,84)
prediction[..., :4] = xywh2xyxy(prediction[..., :4]) # xywh to xyxy
5、初始化输出
- 获取当前时间戳 `t`,用于监控处理时间。
- 初始化一个长度为 batch_size 的列表,每个元素是一个形状为 (0, 6 + nm) 的空张量,用于存储每张图像的检测结果。
t = time.time()
output = [torch.zeros((0, 6 + nm), device=prediction.device)] * bs
6、主循环
- 遍历每一张图像。
- 应用宽度和高度约束(这部分被注释掉了)。
- 保留那些置信度高于阈值的候选框。
for xi, x in enumerate(prediction): # image index, image inference
# Apply constraints
# x[((x[:, 2:4] < min_wh) | (x[:, 2:4] > max_wh)).any(1), 4] = 0 # width-height
x = x[xc[xi]] # confidence
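执行 `x = x[xc[xi]]` 之前:x 包含这张图像的全部预测框,形状为 (num_boxes, 4 + nc + nm),例如 (6300, 84)。
执行 `x = x[xc[xi]]` 之后:只保留 xc[xi] 为 True 的行,即最大类别置信度高于 conf_thres 的候选框;列数不变,行数大幅减少。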
7、合并先验标签
- 如果存在先验标签,将它们与预测结果合并。(我猜是为了验证阶段做准备;源码注释写的是 autolabelling:每个先验标签会被转换成 xyxy 框,并把对应类别的分数设为 1.0,然后拼接到预测结果后面,因此它们一定能通过置信度阈值并参与后面的 NMS。)
# Cat apriori labels if autolabelling
if labels and len(labels[xi]):
lb = labels[xi]
v = torch.zeros((len(lb), nc + nm + 4), device=x.device)
v[:, :4] = xywh2xyxy(lb[:, 1:5]) # box
v[range(len(lb)), lb[:, 0].long() + 4] = 1.0 # cls
x = torch.cat((x, v), 0)
8、跳过空图像
- 如果没有任何候选框,则跳过这张图像。
# If none remain process next image
if not x.shape[0]:
continue
9、拆分预测
- 将预测结果拆分为边界框坐标、类别概率和掩码。
# Detections matrix nx6 (xyxy, conf, cls)
box, cls, mask = x.split((4, nc, nm), 1)
10、处理多标签
- 如果启用了多标签模式,对于每个类别,选择置信度高于阈值的候选框。(多标签就是指一个检测框可能对应多个类别的情况,假设我们有一个边界框包围了一个场景中的自行车,但自行车上还载着一个人。在这种情况下,如果我们希望模型能够同时识别出“自行车”和“人”这两个类别,那么就需要支持多标签。)
- 否则,只保留每个候选框的最佳类别。
if multi_label:
i, j = torch.where(cls > conf_thres) # i 是行索引即目标索引,j 是列索引即类别索引
x = torch.cat((box[i], x[i, 4 + j, None], j[:, None].float(), mask[i]), 1)
else: # best class only
conf, j = cls.max(1, keepdim=True)
x = torch.cat((box, conf, j.float(), mask), 1)[conf.view(-1) > conf_thres]
11、类别过滤
- 如果指定了要保留的类别,则仅保留这些类别的检测框。
# Filter by class
if classes is not None:
x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]
这个 classes 参数是从 cfg 配置文件传进来的。我之前不知道有这个参数,都是自己在模型推理结果 Result 对象里按类别过滤;今天调试 NMS 时才发现它,所以我决定把 cfg 里的参数都弄明白是什么意思。
12、裁剪检测框
- 如果没有检测框,则跳过该图像。
- 如果检测框数量过多,则按置信度排序并裁剪到最大数量。
# Check shape
n = x.shape[0] # number of boxes
if not n: # no boxes
continue
if n > max_nms: # excess boxes
x = x[x[:, 4].argsort(descending=True)[:max_nms]] # sort by confidence and remove excess boxes
13、执行 NMS
- 如果启用了类别无关模式,则不考虑类别。也就是说,无论检测框属于哪个类别,只要它们之间有足够的重叠(根据IoU阈值),就会被视为重复并被抑制。
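- 执行 NMS 算法(torchvision.ops.nms),返回保留框的索引,按分数从高到低排列。(torchvision 的 nms 底层是用 C++ 写的,在文章最后我提供一个 Python 版本的 nms 便于理解;注意那个版本计算 IoU 时宽高按 +1 计算,所以 IoU 数值与 torchvision 略有差异。)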
- 限制每张图像保留的检测框数量(max_det,默认 300)。PS:这个 max_det 参数给我坑惨了(事先没有读 cfg 所有参数的后果)。在做密集目标检测的时候,我们总是有很多漏检,还以为是下采样时特征丢失了,加大图像分辨率、加一个更小的检测头,啥方法都试过了,最后居然是这个参数的问题,醉了……
# Batched NMS
c = x[:, 5:6] * (0 if agnostic else max_wh) # classes
boxes, scores = x[:, :4] + c, x[:, 4] # boxes (offset by class), scores
i = torchvision.ops.nms(boxes, scores, iou_thres) # NMS
i = i[:max_det] # limit detections
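agnostic 这个参数决定是否启用类别无关的 NMS:如果为 True,偏移量 c 全为 0,相当于把所有框当成同一类;否则每个框的偏移量是 类别索引 × max_wh。max_wh 默认 7680,作用是区分不同的类别:pytorch 实现的 nms 只会按 IoU 滤掉重叠超过阈值的框,并不区分类别;给不同类别的框加上不同的偏移量后,不同类别的框在坐标上被远远分开、彼此不再重叠,一次 nms 调用就等价于按类别分别做 nms(前提是框的坐标都小于 max_wh)。(刚开始我不理解为什么要加这个偏移量:加了偏移量,框的位置不就改变了吗?其实加了偏移的 boxes 只作为 nms 的输入,最后保存的是 x[i],也就是没有加偏移的原始框。)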
nms 返回的是保留框的索引 i,用 x[i] 取出对应的检测结果保存到 output[xi];如果总耗时超过 time_limit,就打印警告并提前结束循环(剩余图像的结果保持为空张量),最后返回 output。
output[xi] = x[i]
if (time.time() - t) > time_limit:
LOGGER.warning(f'WARNING ⚠️ NMS time limit {time_limit:.3f}s exceeded')
break # time limit exceeded
return output
import numpy as np
def intersection_over_union(boxA, boxB):
    """
    Compute the intersection-over-union (IoU) of two axis-aligned boxes.

    Both boxes are given as [xmin, ymin, xmax, ymax]. Widths and heights use the
    inclusive-pixel "+1" convention, so the box [0, 0, 9, 9] has area 10 * 10 = 100.

    Returns:
        float: IoU, in [0, 1] for well-formed boxes (xmax >= xmin, ymax >= ymin).
    """
    def _area(b):
        # inclusive-pixel area: (xmax - xmin + 1) * (ymax - ymin + 1)
        return (b[2] - b[0] + 1) * (b[3] - b[1] + 1)

    # Overlap extents; clamped at 0 when the boxes do not intersect on that axis.
    overlap_w = max(0, min(boxA[2], boxB[2]) - max(boxA[0], boxB[0]) + 1)
    overlap_h = max(0, min(boxA[3], boxB[3]) - max(boxA[1], boxB[1]) + 1)
    overlap = overlap_w * overlap_h

    union = _area(boxA) + _area(boxB) - overlap
    return overlap / float(union)
def non_max_suppression(boxes, scores, iou_threshold=0.5):
    """
    Greedy single-class non-maximum suppression (NMS) implemented with NumPy.

    Repeatedly keeps the highest-scoring remaining box and discards every remaining
    box whose IoU with it is strictly greater than ``iou_threshold``.

    Args:
        boxes: Array-like of shape (N, 4); each row is [xmin, ymin, xmax, ymax].
        scores: Array-like of shape (N,); one confidence score per box.
        iou_threshold (float): Boxes with IoU strictly greater than this are suppressed.

    Returns:
        list[int]: Indices into ``boxes`` of the kept boxes, in descending score order.
        Note that indices are returned, not the boxes themselves.

    Note:
        IoU uses the inclusive-pixel "+1" width/height convention, so values differ
        slightly from ``torchvision.ops.nms``, which measures width as ``x2 - x1``.
    """
    # Highest score first: ascending argsort, then reversed.
    order = np.argsort(scores)[::-1]
    keep = []
    if order.size == 0:
        return keep

    boxes = np.asarray(boxes)
    x1, y1, x2, y2 = boxes[:, 0], boxes[:, 1], boxes[:, 2], boxes[:, 3]
    areas = (x2 - x1 + 1) * (y2 - y1 + 1)

    while order.size > 0:
        best, rest = order[0], order[1:]
        keep.append(int(best))  # plain int, not np.int64

        # IoU of the current best box against all remaining boxes, vectorized.
        inter_w = np.maximum(0, np.minimum(x2[best], x2[rest]) - np.maximum(x1[best], x1[rest]) + 1)
        inter_h = np.maximum(0, np.minimum(y2[best], y2[rest]) - np.maximum(y1[best], y1[rest]) + 1)
        inter = inter_w * inter_h
        ious = inter / (areas[best] + areas[rest] - inter)

        # Drop boxes above the threshold; written as ~(>) so a NaN IoU keeps the box,
        # matching an element-wise "remove if iou > threshold" rule.
        order = rest[~(ious > iou_threshold)]
    return keep
# Example usage: single-class NMS on three heavily overlapping boxes.
if __name__ == "__main__":
    # each row: [xmin, ymin, xmax, ymax]
    demo_boxes = np.array(
        [
            [10, 10, 40, 40],
            [11, 12, 43, 43],
            [9, 9, 39, 38],
        ]
    )
    demo_scores = np.array([0.9, 0.8, 0.7])  # one confidence score per box
    demo_iou_thresh = 0.1  # IoU threshold

    kept = non_max_suppression(demo_boxes, demo_scores, iou_threshold=demo_iou_thresh)
    print("保留的边界框索引:", kept)