nms非极大抑制

该博客详细解析了YOLOv5中非极大值抑制(NMS)的原理和代码实现,包括获取预测框信息、筛选高置信度框、去除重复检测和限制处理时间等步骤。博主深入解释了模型输出结构、NMS参数设置以及关键代码段的作用,有助于读者理解目标检测模型的后处理过程。
摘要由CSDN通过智能技术生成

基本原理

  1. 获取batch size,类别数,候选框基本信息
  2. 将传入的所有候选框先通过置信度进行排序
  3. 取置信度最高的作为作为第一个正解,然后通过iou_thres剔除重复检测
  4. 小于iou_thres阈值的重复第2,3步直到得到全部正解

原理图

请添加图片描述

代码解析

part.1 参数

(1)prediction
pred = model(img, augment=False)[0]

调用关系是:model—>forward—>_forward_once
_forward_once的返回值 : 一个tensor list 存放三个元素 [bs, anchor_num, grid_w, grid_h, xywh+c+20classes]

 def forward(self, x, augment=False, profile=False, visualize=False):
        return self._forward_once(x, profile, visualize)  # single-scale inference, train

 def _forward_once(self, x, profile=False, visualize=False):
        y, dt = [], []  # outputs
        for m in self.model:
         #前向推理每一层结构   m.i=index   m.f=from   m.type=类名   m.np=number of params
        # if not from previous layer   m.f=当前层的输入来自哪一层的输出  s的m.f都是-1
            if m.f != -1:  # if not from previous layer
                x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f]  # from earlier layers
            if profile:
                self._profile_one_layer(m, x, dt)
            x = m(x)  # run
            y.append(x if m.i in self.save else None)  # save output
            if visualize:
                feature_visualization(x, m.type, m.i, save_dir=visualize)
        return x
       

其实你可以打开yolov5s.yaml,m.f != -1只有4个concat操作和1个Detect操作
e.g.

[[-1, 6], 1, Concat, [1]],  # cat backbone P4
[[17, 20, 23], 1, Detect, [nc, anchors]]

concat操作和Detect操作的第一位都是list类型

if m.f != -1:  # if not from previous layer
	x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f]  # from earlier layers

这段代码本小白也是理解一段时间,其实是等价于

y=[]
if isinstance(m.f, int):
	x=y.append(m.f)
	else:
		for j in m.f:
			if j == -1:x=y.append(x)
			else:x=y.append(y[j])

说白了就是将取出输入不是或不只是从上一层 的对应的层的结果,准备后面的进入对应m的forward()

predict的结构

nc = prediction.shape[2] - 5 # number of classes

prediction是网络模型的直接输出
输出其shape是(1, 50000, 7), 1表示的是图片的个数,50000表示是网络预测的候选框的个数,7表示一组数其意义如下:
在这里插入图片描述

由这个可以看出来,nc可以得到网络预测的类别个数。

xc = prediction[, 4] > conf_thres # 目标框中含有目标的概率值

由上图可知,prediction[…, 4]是第5个框,代表含有目标的概率,> conf_thres使结果变成bool类型,表示是否判断当前目标框中含有目标。prediction[…, 4]的shape是(1, 50000),因此xc也是shape为(1, 50000),类型为bool的一个tensor。

在后面通过检查其是否为TRUE(判断目标框中含有目标):

# If none remain process next image
        if not x.shape[0]:
            continue

一些变量的含义

 max_wh = 7680  # (像素)最大框宽和高度
    max_nms = 30000  # torchvision.ops.nms()中的最大候选框数量
    time_limit = 0.3 + 0.03 * bs  # seconds to quit after
    redundant = True  # require redundant detections
    multi_label &= nc > 1  # multiple labels per box (adds 0.5ms/img)
    merge = False  # use merge-NMS

定义输出

output = [torch.zeros((0, 6), device=prediction.device)] * prediction.shape[0]
  • prediction.shape[0]确定了输出是长度为1,包含六个张量的list

part.2完整代码解读

很多内容参照,学习yolov5 nms 源码理解
以下代码取自yolov5 v6.x

def non_max_suppression(prediction,
                        conf_thres=0.25,
                        iou_thres=0.45,
                        classes=None,
                        agnostic=False,
                        multi_label=False,
                        labels=(),
                        max_det=300):
    """Non-Maximum Suppression (NMS) on inference results to reject overlapping bounding boxes

    Returns:
         list of detections, on (n,6) tensor per image [xyxy, conf, cls]
    """

    bs = prediction.shape[0]  # batch size
    nc = prediction.shape[2] - 5  # number of classes
    xc = prediction[..., 4] > conf_thres  # candidates

    # Checks
    assert 0 <= conf_thres <= 1, f'Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0'
    assert 0 <= iou_thres <= 1, f'Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0'

    # Settings
    # min_wh = 2  # (pixels) minimum box width and height
    max_wh = 7680  # (pixels) maximum box width and height
    max_nms = 30000  # maximum number of boxes into torchvision.ops.nms()
    time_limit = 0.3 + 0.03 * bs  # seconds to quit after
    redundant = True  # require redundant detections
    multi_label &= nc > 1  # multiple labels per box (adds 0.5ms/img)
    merge = False  # use merge-NMS

    t = time.time()
    output = [torch.zeros((0, 6), device=prediction.device)] * bs
    for xi, x in enumerate(prediction):  # image index, image inference
        # Apply constraints
        # x[((x[..., 2:4] < min_wh) | (x[..., 2:4] > max_wh)).any(1), 4] = 0  # width-height
        x = x[xc[xi]]  # confidence

        # Cat apriori labels if autolabelling
        if labels and len(labels[xi]):
            lb = labels[xi]
            v = torch.zeros((len(lb), nc + 5), device=x.device)
            v[:, :4] = lb[:, 1:5]  # box
            v[:, 4] = 1.0  # conf
            v[range(len(lb)), lb[:, 0].long() + 5] = 1.0  # cls
            x = torch.cat((x, v), 0)

        # If none remain process next image
        if not x.shape[0]:
            continue

        # Compute conf
        x[:, 5:] *= x[:, 4:5]  # conf = obj_conf * cls_conf

        # Box (center x, center y, width, height) to (x1, y1, x2, y2)
        box = xywh2xyxy(x[:, :4])

        # Detections matrix nx6 (xyxy, conf, cls)
        if multi_label:
            i, j = (x[:, 5:] > conf_thres).nonzero(as_tuple=False).T
            x = torch.cat((box[i], x[i, j + 5, None], j[:, None].float()), 1)
        else:  # best class only
            conf, j = x[:, 5:].max(1, keepdim=True)
            x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > conf_thres]

        # Filter by class
        if classes is not None:
            x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]

        # Apply finite constraint
        # if not torch.isfinite(x).all():
        #     x = x[torch.isfinite(x).all(1)]

        # Check shape
        n = x.shape[0]  # number of boxes
        if not n:  # no boxes
            continue
        elif n > max_nms:  # excess boxes
            x = x[x[:, 4].argsort(descending=True)[:max_nms]]  # sort by confidence

        # Batched NMS
        c = x[:, 5:6] * (0 if agnostic else max_wh)  # classes
        boxes, scores = x[:, :4] + c, x[:, 4]  # boxes (offset by class), scores
        i = torchvision.ops.nms(boxes, scores, iou_thres)  # NMS
        if i.shape[0] > max_det:  # limit detections
            i = i[:max_det]
        if merge and (1 < n < 3E3):  # Merge NMS (boxes merged using weighted mean)
            # update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
            iou = box_iou(boxes[i], boxes) > iou_thres  # iou matrix
            weights = iou * scores[None]  # box weights
            x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True)  # merged boxes
            if redundant:
                i = i[iou.sum(1) > 1]  # require redundancy

        output[xi] = x[i]
        if (time.time() - t) > time_limit:
            LOGGER.warning(f'WARNING: NMS time limit {time_limit:.3f}s exceeded')
            break  # time limit exceeded

    return output

细节上

如果刷性能分的话,

multi_label &= nc > 1  # multiple labels per box (adds 0.5ms/img)

没有必要将所有类别都记入,这样会增加时间成本

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值