YOLOv5中NMS处理过程解析

NMS(non_max_suppression)测试数据为YOLOv5模型输出数据(1, 9072, 31),一共有26类预测结果
1、数据形式说明[1, 9072,xywh, conf, cls]
其中xywh为(center x, center y, width, height)

xywh含有目标的概率类别1的概率类别2的概率……
4001.68.545341.984619.41060.4616470.01531310.9768897……

2、初步置信度筛选:xc = pre[…, 4] > conf_thres
conf_thres=0.4 xc.size[1, 9072]
先筛选置信度超过阈值conf_thres的预测结果conf
3、设置参数:min_wh, max_wh = 2, 4096,最大检索次数1000次
4、从pre中分离数据:for xi, x in enumerate(pre):x.shape:[9072,31], xi是索引(0)
5、筛除置信度小于0.4的box数据: x = x[xc[xi]]
confidence 获取所有置信度>0.4的box信息,不符合的筛掉
6、 x为置信度大于0.4的信息:x.shape:[33, 31]
7、置信度计算 :x[:, 5:] *= x[:, 4:5] //conf = obj_conf * cls_conf 论文里的公式:
在这里插入图片描述
x[:, 5:]对应9072个26列tensor,x[:, 4:5],对应9072个1列tensor,x[:, 5:] *= x[:, 4:5]则表示为26列的每个元素和一列的每个元素分别相乘得到26列的元素, 元素再赋值给x[:, 5:]的26列。经过这一步计算x的后面26列的值就是代表了目标的所在类别的confidence值了。
8、关于box:此时x[:, :4]数据形式为(center x, center y, width, height) 转化为 (x1, y1, x2, y2)= (top left x,top left y,bottom right x,bottom right y)
9、通过置信度选择符合置信度预值的box,重新赋值给x(此时的x为(xyxy, conf, cls),x的0到3列存放box,4列存放conf,5列存放class种类id:选取x[:, 5:] 每一行的最大值,以及最大值的索引;其中每一行的最大值为置信度,索引为class。
将box,conf,以及j 按照列cat成一个tensor作为网络的输出。[conf.view(-1) > conf_thres]则是筛选出confidence值大于conf_thres所有box
(其中,view()的作用相当于numpy中的reshape,重新定义矩阵的形状,参数为-1时代表动态调整形状)
x.shape:[33,6]
x.shape[0]是box个数,如果box为0下一张;如果box大于1000,则按照置信度降序排列,取出1000个数的box做nms
10、Batched NMS
c = x[:, 5:6] * max_wh
boxes = x[:, :4] + c
scores =x[:, 4] (置信度)
x[:, :4]表示box(从二维看第0,1,2,3列)
x[:, 4] 表示分数(从二维看第4列)
x[:, 5:6]表示类IDX(从二维看第5列)
max_wh这里是4096,这样偏移量仅取决于类IDX,并且足够大。
11、调用torch的nms
i = torchvision.ops.nms(boxes, scores, iou_thres)
①选取这类box中scores最大的哪一个,记为box_best,并保留它
②计算box_best与其余的box的IOU
③如果其IOU>iou_thres了,那么就舍弃这个box(由于可能这两个box表示同一目标,所以保留分数高的一个)
从最后剩余的boxes中,再找出最大scores的哪一个,如此循环往复。

注*:1)从所有候选框中选取置信度最高的预测边界框B1作为基准,然后将所有与B1的IOU超过预定阈值的其他边界框移除。
(这时所有边界框中B1为置信度最高的边界框且没有和其太过相似的边界框——非极大值置信度的边界框被抑制了)
2)从所有候选框中选取置信度第二高的边界框B2作为一个基准,将所有与B2的IOU超过预定阈值的其他边界框移除。
3)重复上述操作,直到所有预测框都被当做基准——这时候没有一对边界框过于相似。

12、输出:
output[xi] tensor([[x1, y1, w1, h1, conf1, cls1],
[x2, y2, w2, h2, conf2, cls2],
[x3, y3, w3, h3, conf3, cls3],
……])
输出形状固定为(n,6)
本文为个人对源代码理解,若有误区请多指教,私信必回

  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值