NMS(non_max_suppression)测试数据为YOLOv5模型输出数据(1, 9072, 31),一共有26类预测结果
1、数据形式说明[1, 9072,xywh, conf, cls]
其中xywh为(center x, center y, width, height)
x | y | w | h | 含有目标的概率 | 类别1的概率 | 类别2的概率 | …… |
---|---|---|---|---|---|---|---|
4001.6 | 8.5453 | 41.9846 | 19.4106 | 0.461647 | 0.0153131 | 0.9768897 | …… |
2、初步置信度筛选:xc = pre[…, 4] > conf_thres
conf_thres=0.4 xc.size[1, 9072]
先筛选置信度超过阈值conf_thres的预测结果conf
3、设置参数:min_wh, max_wh = 2, 4096,最大检索次数1000次
4、从pre中分离数据:for xi, x in enumerate(pre):x.shape:[9072,31], xi是索引(0)
5、筛除置信度小于0.4的box数据: x = x[xc[xi]]
confidence 获取所有置信度>0.4的box信息,不符合的筛掉
6、 x为置信度大于0.4的信息:x.shape:[33, 31]
7、置信度计算 :x[:, 5:] *= x[:, 4:5] //conf = obj_conf * cls_conf 论文里的公式:
x[:, 5:]对应9072个26列tensor,x[:, 4:5],对应9072个1列tensor,x[:, 5:] *= x[:, 4:5]则表示为26列的每个元素和一列的每个元素分别相乘得到26列的元素, 元素再赋值给x[:, 5:]的26列。经过这一步计算x的后面26列的值就是代表了目标的所在类别的confidence值了。
8、关于box:此时x[:, :4]数据形式为(center x, center y, width, height) 转化为 (x1, y1, x2, y2)= (top left x,top left y,bottom right x,bottom right y)
9、通过置信度选择符合置信度预值的box,重新赋值给x(此时的x为(xyxy, conf, cls),x的0到3列存放box,4列存放conf,5列存放class种类id:选取x[:, 5:] 每一行的最大值,以及最大值的索引;其中每一行的最大值为置信度,索引为class。
将box,conf,以及j 按照列cat成一个tensor作为网络的输出。[conf.view(-1) > conf_thres]则是筛选出confidence值大于conf_thres所有box
(其中,view()的作用相当于numpy中的reshape,重新定义矩阵的形状,参数为-1时代表动态调整形状)
x.shape:[33,6]
x.shape[0]是box个数,如果box为0下一张;如果box大于1000,则按照置信度降序排列,取出1000个数的box做nms
10、Batched NMS
c = x[:, 5:6] * max_wh
boxes = x[:, :4] + c
scores =x[:, 4] (置信度)
x[:, :4]表示box(从二维看第0,1,2,3列)
x[:, 4] 表示分数(从二维看第4列)
x[:, 5:6]表示类IDX(从二维看第5列)
max_wh这里是4096,这样偏移量仅取决于类IDX,并且足够大。
11、调用torch的nms
i = torchvision.ops.nms(boxes, scores, iou_thres)
①选取这类box中scores最大的哪一个,记为box_best,并保留它
②计算box_best与其余的box的IOU
③如果其IOU>iou_thres了,那么就舍弃这个box(由于可能这两个box表示同一目标,所以保留分数高的一个)
从最后剩余的boxes中,再找出最大scores的哪一个,如此循环往复。
注*:1)从所有候选框中选取置信度最高的预测边界框B1作为基准,然后将所有与B1的IOU超过预定阈值的其他边界框移除。
(这时所有边界框中B1为置信度最高的边界框且没有和其太过相似的边界框——非极大值置信度的边界框被抑制了)
2)从所有候选框中选取置信度第二高的边界框B2作为一个基准,将所有与B2的IOU超过预定阈值的其他边界框移除。
3)重复上述操作,直到所有预测框都被当做基准——这时候没有一对边界框过于相似。
12、输出:
output[xi] tensor([[x1, y1, w1, h1, conf1, cls1],
[x2, y2, w2, h2, conf2, cls2],
[x3, y3, w3, h3, conf3, cls3],
……])
输出形状固定为(n,6)
本文为个人对源代码理解,若有误区请多指教,私信必回