NMS (Ultralytics Source Code Debug)

Source Code

import time

import torch
import torchvision

# Imports needed to run this function outside ultralytics/utils/ops.py,
# where it is originally defined:
from ultralytics.utils import LOGGER
from ultralytics.utils.ops import xywh2xyxy


def non_max_suppression(
        prediction,
        conf_thres=0.25,
        iou_thres=0.45,
        classes=None,
        agnostic=False,
        multi_label=False,
        labels=(),
        max_det=300,
        nc=0,  # number of classes (optional)
        max_time_img=0.05,
        max_nms=30000,
        max_wh=7680,
):
    """
    Perform non-maximum suppression (NMS) on a set of boxes, with support for masks and multiple labels per box.

    Args:
        prediction (torch.Tensor): A tensor of shape (batch_size, num_classes + 4 + num_masks, num_boxes)
            containing the predicted boxes, classes, and masks. The tensor should be in the format
            output by a model, such as YOLO.
        conf_thres (float): The confidence threshold below which boxes will be filtered out.
            Valid values are between 0.0 and 1.0.
        iou_thres (float): The IoU threshold below which boxes will be filtered out during NMS.
            Valid values are between 0.0 and 1.0.
        classes (List[int]): A list of class indices to consider. If None, all classes will be considered.
        agnostic (bool): If True, the model is agnostic to the number of classes, and all
            classes will be considered as one.
        multi_label (bool): If True, each box may have multiple labels.
        labels (List[List[Union[int, float, torch.Tensor]]]): A list of lists, where each inner
            list contains the apriori labels for a given image. The list should be in the format
            output by a dataloader, with each label being a tuple of (class_index, x1, y1, x2, y2).
        max_det (int): The maximum number of boxes to keep after NMS.
        nc (int, optional): The number of classes output by the model. Any indices after this will be considered masks.
        max_time_img (float): The maximum time (seconds) for processing one image.
        max_nms (int): The maximum number of boxes into torchvision.ops.nms().
        max_wh (int): The maximum box width and height in pixels

    Returns:
        (List[torch.Tensor]): A list of length batch_size, where each element is a tensor of
            shape (num_boxes, 6 + num_masks) containing the kept boxes, with columns
            (x1, y1, x2, y2, confidence, class, mask1, mask2, ...).
    """

    # Checks
    assert 0 <= conf_thres <= 1, f'Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0'
    assert 0 <= iou_thres <= 1, f'Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0'
    if isinstance(prediction, (list, tuple)):  # YOLOv8 model in validation mode, output = (inference_out, loss_out)
        prediction = prediction[0]  # select only inference output

    bs = prediction.shape[0]  # batch size
    nc = nc or (prediction.shape[1] - 4)  # number of classes
    nm = prediction.shape[1] - nc - 4
    mi = 4 + nc  # mask start index
    xc = prediction[:, 4:mi].amax(1) > conf_thres  # candidates

    # Settings
    # min_wh = 2  # (pixels) minimum box width and height
    time_limit = 0.5 + max_time_img * bs  # seconds to quit after
    multi_label &= nc > 1  # multiple labels per box (adds 0.5ms/img)

    prediction = prediction.transpose(-1, -2)  # shape(1,84,6300) to shape(1,6300,84)
    prediction[..., :4] = xywh2xyxy(prediction[..., :4])  # xywh to xyxy

    t = time.time()
    output = [torch.zeros((0, 6 + nm), device=prediction.device)] * bs
    for xi, x in enumerate(prediction):  # image index, image inference
        # Apply constraints
        # x[((x[:, 2:4] < min_wh) | (x[:, 2:4] > max_wh)).any(1), 4] = 0  # width-height
        x = x[xc[xi]]  # confidence

        # Cat apriori labels if autolabelling
        if labels and len(labels[xi]):
            lb = labels[xi]
            v = torch.zeros((len(lb), nc + nm + 4), device=x.device)
            v[:, :4] = xywh2xyxy(lb[:, 1:5])  # box
            v[range(len(lb)), lb[:, 0].long() + 4] = 1.0  # cls
            x = torch.cat((x, v), 0)

        # If none remain process next image
        if not x.shape[0]:
            continue

        # Detections matrix nx6 (xyxy, conf, cls)
        box, cls, mask = x.split((4, nc, nm), 1)

        if multi_label:
            i, j = torch.where(cls > conf_thres)
            x = torch.cat((box[i], x[i, 4 + j, None], j[:, None].float(), mask[i]), 1)
        else:  # best class only
            conf, j = cls.max(1, keepdim=True)
            x = torch.cat((box, conf, j.float(), mask), 1)[conf.view(-1) > conf_thres]

        # Filter by class
        if classes is not None:
            x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]

        # Check shape
        n = x.shape[0]  # number of boxes
        if not n:  # no boxes
            continue
        if n > max_nms:  # excess boxes
            x = x[x[:, 4].argsort(descending=True)[:max_nms]]  # sort by confidence and remove excess boxes

        # Batched NMS
        c = x[:, 5:6] * (0 if agnostic else max_wh)  # classes
        boxes, scores = x[:, :4] + c, x[:, 4]  # boxes (offset by class), scores
        i = torchvision.ops.nms(boxes, scores, iou_thres)  # NMS
        i = i[:max_det]  # limit detections

        # # Experimental
        # merge = False  # use merge-NMS
        # if merge and (1 < n < 3E3):  # Merge NMS (boxes merged using weighted mean)
        #     # Update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
        #     from .metrics import box_iou
        #     iou = box_iou(boxes[i], boxes) > iou_thres  # iou matrix
        #     weights = iou * scores[None]  # box weights
        #     x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True)  # merged boxes
        #     redundant = True  # require redundant detections
        #     if redundant:
        #         i = i[iou.sum(1) > 1]  # require redundancy

        output[xi] = x[i]
        if (time.time() - t) > time_limit:
            LOGGER.warning(f'WARNING ⚠️ NMS time limit {time_limit:.3f}s exceeded')
            break  # time limit exceeded

    return output

Explanation

1. Argument checks

    # Checks
    assert 0 <= conf_thres <= 1, f'Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0'
    assert 0 <= iou_thres <= 1, f'Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0'
    if isinstance(prediction, (list, tuple)):  # YOLOv8 model in validation mode, output = (inference_out, loss_out)
        prediction = prediction[0]  # select only inference output

2. Variable initialization

    bs = prediction.shape[0]  # batch size
    nc = nc or (prediction.shape[1] - 4)  # number of classes
    nm = prediction.shape[1] - nc - 4 
    mi = 4 + nc  # mask start index
    xc = prediction[:, 4:mi].amax(1) > conf_thres  # candidates

nm is presumably the number of mask coefficients, which is used for instance segmentation (it is 0 for plain detection). The key line is xc: amax(1) takes each box's highest class score, so xc is a boolean mask marking the candidate boxes whose best score exceeds the confidence threshold.
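A minimal sketch of the candidate mask, with made-up sizes (nc = 3, nm = 0, 5 boxes) and random values:

import torch

# Toy prediction: batch of 1, 4 box coords + 3 class scores, 5 boxes.
prediction = torch.rand(1, 4 + 3, 5)
conf_thres = 0.25
mi = 4 + 3  # mask start index

# amax(1) takes each box's best class score; xc has shape (1, 5).
xc = prediction[:, 4:mi].amax(1) > conf_thres
print(xc)  # e.g. tensor([[ True, False,  True,  True, False]])

# Inside the loop, each image keeps only its candidate rows:
x = prediction.transpose(-1, -2)[0]  # shape (5, 7)
print(x[xc[0]].shape)                # e.g. torch.Size([3, 7])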

3. Time limit and multi-label settings

  • Define a time limit; processing stops once it is exceeded.
  • Enable multi-label mode only when there is more than one class.
    # Settings
    # min_wh = 2  # (pixels) minimum box width and height
    time_limit = 0.5 + max_time_img * bs  # seconds to quit after
    multi_label &= nc > 1  # multiple labels per box (adds 0.5ms/img)

4. Data preparation

  • Transpose the prediction tensor so its shape becomes (batch_size, num_boxes, num_classes + 4 + num_masks).
  • Convert box coordinates from center-width-height (xywh) to corner (xyxy) format; a minimal sketch of the conversion follows the code.
    prediction = prediction.transpose(-1, -2)  # shape(1,84,6300) to shape(1,6300,84)
    prediction[..., :4] = xywh2xyxy(prediction[..., :4])  # xywh to xyxy
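xywh2xyxy lives in ultralytics.utils.ops; a minimal equivalent, for illustration only, looks like this:

import torch

def xywh2xyxy_sketch(x):
    # Convert (cx, cy, w, h) boxes to (x1, y1, x2, y2); a minimal stand-in
    # for ultralytics.utils.ops.xywh2xyxy.
    y = x.clone()
    y[..., 0] = x[..., 0] - x[..., 2] / 2  # x1 = cx - w/2
    y[..., 1] = x[..., 1] - x[..., 3] / 2  # y1 = cy - h/2
    y[..., 2] = x[..., 0] + x[..., 2] / 2  # x2 = cx + w/2
    y[..., 3] = x[..., 1] + x[..., 3] / 2  # y2 = cy + h/2
    return y

print(xywh2xyxy_sketch(torch.tensor([[50., 50., 20., 10.]])))
# tensor([[40., 45., 60., 55.]])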


5. Output initialization

  • Record the current timestamp t, used to enforce the time limit.
  • Initialize the output as a list of empty tensors, one per image (note the shared-reference subtlety shown after the code).
    t = time.time()
    output = [torch.zeros((0, 6 + nm), device=prediction.device)] * bs
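One Python subtlety worth noting: [tensor] * bs puts bs references to the same empty tensor into the list. That is harmless here because output[xi] = x[i] later rebinds a slot instead of mutating the shared tensor:

import torch

output = [torch.zeros((0, 6))] * 3
print(output[0] is output[2])            # True: all slots share one tensor

output[1] = torch.ones((2, 6))           # rebinding leaves the other slots untouched
print(output[0].shape, output[1].shape)  # torch.Size([0, 6]) torch.Size([2, 6])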


6. Main loop

  • Iterate over each image in the batch.
  • Apply width/height constraints (commented out in the current source).
  • Keep only the candidate boxes whose confidence exceeds the threshold.
    for xi, x in enumerate(prediction):  # image index, image inference
        # Apply constraints
        # x[((x[:, 2:4] < min_wh) | (x[:, 2:4] > max_wh)).any(1), 4] = 0  # width-height
        x = x[xc[xi]]  # confidence

(Debugger screenshots omitted: the contents of x before and after "x = x[xc[xi]]", showing the candidate filter shrinking the number of rows.)

7. Concatenating apriori labels

  • If apriori labels exist, concatenate them with the predictions (my guess: this prepares for the validation stage). A small sketch of the construction follows the code.
        # Cat apriori labels if autolabelling
        if labels and len(labels[xi]):
            lb = labels[xi]
            v = torch.zeros((len(lb), nc + nm + 4), device=x.device)
            v[:, :4] = xywh2xyxy(lb[:, 1:5])  # box
            v[range(len(lb)), lb[:, 0].long() + 4] = 1.0  # cls
            x = torch.cat((x, v), 0)
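A small sketch of what this builds, assuming nc = 3, nm = 0, and a single apriori label of class 2 (the box conversion result is filled in by hand):

import torch

nc, nm = 3, 0
lb = torch.tensor([[2.0, 50.0, 50.0, 20.0, 10.0]])   # (class, cx, cy, w, h)

v = torch.zeros((len(lb), nc + nm + 4))
v[:, :4] = torch.tensor([[40.0, 45.0, 60.0, 55.0]])  # xywh2xyxy(lb[:, 1:5])
v[range(len(lb)), lb[:, 0].long() + 4] = 1.0         # one-hot class score of 1.0
print(v)
# tensor([[40., 45., 60., 55.,  0.,  0.,  1.]])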

8. Skipping empty images

  • If no candidate boxes remain, skip this image.
        # If none remain process next image
        if not x.shape[0]:
            continue

9. Splitting predictions

  • Split the predictions into box coordinates, class scores, and mask coefficients.
        # Detections matrix nx6 (xyxy, conf, cls)
        box, cls, mask = x.split((4, nc, nm), 1)


10. Multi-label handling

  • In multi-label mode, every (box, class) pair whose score exceeds the threshold becomes its own detection. (Multi-label means one box may carry several class labels. Suppose a bounding box encloses a bicycle with a person riding it; if we want the model to report both "bicycle" and "person" for that box, we need multi-label support.)
  • Otherwise, only the best class of each box is kept; a toy comparison of the two branches follows the code.
        if multi_label:
            i, j = torch.where(cls > conf_thres)  # i: row (box) indices, j: column (class) indices
            x = torch.cat((box[i], x[i, 4 + j, None], j[:, None].float(), mask[i]), 1)
        else:  # best class only
            conf, j = cls.max(1, keepdim=True)
            x = torch.cat((box, conf, j.float(), mask), 1)[conf.view(-1) > conf_thres]
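A toy comparison of the two branches, assuming two boxes, three classes, and nm = 0:

import torch

conf_thres = 0.25
box = torch.tensor([[0., 0., 10., 10.], [5., 5., 20., 20.]])
cls = torch.tensor([[0.9, 0.6, 0.1],   # box 0 clears the threshold for classes 0 and 1
                    [0.1, 0.2, 0.7]])  # box 1 only for class 2

# multi_label branch: one output row per (box, class) pair above the threshold.
i, j = torch.where(cls > conf_thres)
print(torch.cat((box[i], cls[i, j, None], j[:, None].float()), 1))
# 3 rows: (box0, 0.9, cls 0), (box0, 0.6, cls 1), (box1, 0.7, cls 2)

# best-class branch: one row per box, keeping only the top class.
conf, j = cls.max(1, keepdim=True)
print(torch.cat((box, conf, j.float()), 1)[conf.view(-1) > conf_thres])
# 2 rows: (box0, 0.9, cls 0), (box1, 0.7, cls 2)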


11. Class filtering

  • If a list of classes to keep was specified, retain only detections of those classes.
        # Filter by class
        if classes is not None:
            x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]

The classes argument is passed in from the cfg file. I didn't know about this parameter before and used to filter classes myself on the inference Result object; I only discovered it today while debugging NMS. So I've resolved to figure out what every parameter in the cfg actually means.
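The filter itself is just a broadcast comparison; a minimal example keeping only classes 0 and 2:

import torch

# Fake detections: columns are (x1, y1, x2, y2, conf, cls).
x = torch.tensor([[0., 0., 10., 10., 0.9, 0.],
                  [1., 1., 12., 12., 0.8, 1.],
                  [2., 2., 14., 14., 0.7, 2.]])
classes = [0, 2]

# (n, 1) == (len(classes),) broadcasts to (n, len(classes)); any(1) keeps
# a row if its class matches any requested class.
keep = (x[:, 5:6] == torch.tensor(classes)).any(1)
print(x[keep][:, 5])  # tensor([0., 2.])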

12. Truncating excess boxes

  • If no boxes remain, skip this image.
  • If there are too many, sort by confidence and keep only the top max_nms (see the one-liner after the code).
        # Check shape
        n = x.shape[0]  # number of boxes
        if not n:  # no boxes
            continue
        if n > max_nms:  # excess boxes
            x = x[x[:, 4].argsort(descending=True)[:max_nms]]  # sort by confidence and remove excess boxes
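The truncation is a plain top-k by the confidence column (index 4):

import torch

x = torch.tensor([[0., 0., 1., 1., 0.3, 0.],
                  [0., 0., 1., 1., 0.9, 0.],
                  [0., 0., 1., 1., 0.6, 0.]])
max_nms = 2
print(x[x[:, 4].argsort(descending=True)[:max_nms]][:, 4])  # tensor([0.9000, 0.6000])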

13. Running NMS

  • In class-agnostic mode, classes are ignored: any two boxes with enough overlap (per the IoU threshold) are treated as duplicates, and the lower-scoring one is suppressed, regardless of class.
  • Run NMS, which returns the indices of the kept boxes. (PyTorch's nms is implemented in C++; a pure-Python version is provided at the end of this article.)
  • Limit the number of kept detections. PS: this parameter burned me badly (the price of not reading every cfg parameter up front). On dense-object scenes we kept getting lots of missed detections and assumed features were being lost during downsampling; we raised the image resolution, added a smaller detection head, tried everything, and in the end it was just this parameter. Unbelievable.
        # Batched NMS
        c = x[:, 5:6] * (0 if agnostic else max_wh)  # classes
        boxes, scores = x[:, :4] + c, x[:, 4]  # boxes (offset by class), scores
        i = torchvision.ops.nms(boxes, scores, iou_thres)  # NMS
        i = i[:max_det]  # limit detections

The agnostic flag controls class-agnostic NMS: if True, all class offsets are 0; otherwise each box is offset by its class index times max_wh (default 7680, chosen to be larger than any real box coordinate). torchvision's nms only suppresses boxes whose IoU exceeds the threshold, so adding a per-class offset pushes boxes of different classes so far apart that they can never overlap, making the single NMS call effectively per-class. (At first I didn't understand the offset: doesn't it change the box positions? It doesn't matter, because nms only returns indices, and the rows saved into the output still hold the original, unshifted coordinates. The demo below illustrates this.)
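A quick check of the offset trick: two heavily overlapping boxes of different classes are both kept by class-aware NMS but collapsed by agnostic NMS. The shifted boxes are used only for the IoU computation inside nms:

import torch
import torchvision

boxes = torch.tensor([[0., 0., 10., 10.],
                      [1., 1., 11., 11.]])  # IoU ≈ 0.68
scores = torch.tensor([0.9, 0.8])
cls = torch.tensor([[0.], [1.]])            # two different classes
max_wh = 7680

# Agnostic: no offset, the overlapping pair collapses to one box.
print(torchvision.ops.nms(boxes, scores, 0.45))                 # tensor([0])

# Class-aware: each class lives in its own coordinate "slab", so boxes of
# different classes never overlap and both survive.
print(torchvision.ops.nms(boxes + cls * max_wh, scores, 0.45))  # tensor([0, 1])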

nms returns the kept indices; the corresponding rows are saved and the results returned.

        output[xi] = x[i]
        if (time.time() - t) > time_limit:
            LOGGER.warning(f'WARNING ⚠️ NMS time limit {time_limit:.3f}s exceeded')
            break  # time limit exceeded
    return output
As promised above, here is a pure-Python (NumPy) version of NMS:

import numpy as np
 
 
def intersection_over_union(boxA, boxB):
    # Compute the intersection over union (IoU) of two boxes; the +1 terms
    # treat coordinates as inclusive pixel indices
    xA = max(boxA[0], boxB[0])
    yA = max(boxA[1], boxB[1])
    xB = min(boxA[2], boxB[2])
    yB = min(boxA[3], boxB[3])
 
    interArea = max(0, xB - xA + 1) * max(0, yB - yA + 1)
 
    boxAArea = (boxA[2] - boxA[0] + 1) * (boxA[3] - boxA[1] + 1)
    boxBArea = (boxB[2] - boxB[0] + 1) * (boxB[3] - boxB[1] + 1)
 
    iou = interArea / float(boxAArea + boxBArea - interArea)
    return iou
 
 
def non_max_suppression(boxes, scores, iou_threshold=0.5):
    """
    实现非极大值抑制(NMS),输入是边界框和对应的分数,
    返回经过NMS处理后的边界框列表。
    """
    # Sort indices by score, highest first
    sorted_indices = np.argsort(scores)[::-1]
 
    keep_boxes = []
    while sorted_indices.size > 0:
        # Pick the current highest-scoring box
        idx = sorted_indices[0]
        keep_boxes.append(idx)
 
        # Compute the IoU of the current box against all remaining boxes
        ious = np.array([intersection_over_union(boxes[idx], boxes[i]) for i in sorted_indices[1:]])
 
        # Drop boxes whose IoU with the current box exceeds the threshold
        remove_indices = np.where(ious > iou_threshold)[0] + 1  # +1 because sorted_indices[1:] skipped the current box
        sorted_indices = np.delete(sorted_indices, remove_indices)
        sorted_indices = np.delete(sorted_indices, 0)  # remove the index of the box we just kept
 
    return keep_boxes
 
 
# Example usage
if __name__ == "__main__":
    # Apply NMS to a single class
    boxes = np.array([[10, 10, 40, 40], [11, 12, 43, 43], [9, 9, 39, 38]])  # [xmin, ymin, xmax, ymax]
    scores = np.array([0.9, 0.8, 0.7])  # confidence of each box
    iou_thresh = 0.1   # IoU threshold
 
    # Apply NMS
    indices_to_keep = non_max_suppression(boxes, scores, iou_threshold=iou_thresh)
    print("保留的边界框索引:", indices_to_keep)
