1、什么是非极大抑制
NMS是Non-Maximum Suppression的缩写,主要思想看下图
图一没有经过非极大抑制,图二是非极大抑制后的结果。可以看出,非极大抑制的基本思想就是
筛选出一定区域内属于同一种类得分最大的框。
2、非极大抑制的实现过程(以yolov3为例)
第一个维度是图片的数量。
第二个维度是所有的预测框。
第三个维度是所有的预测框的预测结果。
非极大抑制的执行过程如下所示:
①对所有图片进行循环。
②找出该图片中得分大于门限函数的框。在进行重合框筛选前就进行得分的筛选可以大幅度减少框的数量。
③判断第2步中获得的框的种类与得分。取出预测结果中框的位置与之进行堆叠。此时最后一维度里面的内容由5+num_classes变成了4+1+2,四个参数代表框的位置,一个参数代表预测框是否包含物体,两个参数分别代表种类的置信度与种类。
④对种类进行循环,非极大抑制的作用是筛选出一定区域内属于同一种类得分最大的框,对种类进行循环可以帮助我们对每一个类分别进行非极大抑制。
⑤根据得分对该种类进行从大到小排序。
⑥每次取出得分最大的框,计算其与其它所有预测框的重合程度,重合程度过大的则剔除。
(batch_size,all_boxes,4+1+num_classes) | |||||||
(1,num_boxes,x+y+w+h+1+num_classes) | |||||||
(num_boxes,4+1+numclasses) | YOLO | ||||||
(num_boxes,4+1+2) | x | y | w | h | 存在物体的概率 | 为猫的概率 | 为狗的概率 |
SSD | |||||||
x | y | w | h | 为背景概率 | 为猫的概率 | 为狗的概率 | |
batch_size:第一维度,输入进来图片的数量
all_boxes:第二维度,指向所有的预测框
4+1+num_classes:第三维度,表示框的位置,中心坐标加宽高
3、代码实现过程
①将框转换成左上角右下角的形式
②对每张图片进行循环
③对置信度进行判断
④取出框的种类和置信度
⑤进行unique操作,较少不必要的循环计算
⑥对每个类进行非极大抑制操作。每一次都选取得分最大的框,与其他所有框进行iou计算,若重合程度大,则剔除。
def non_max_suppression(self, prediction, num_classes, input_shape, image_shape, letterbox_image, conf_thres=0.5,nms_thres=0.4):
# 将预测结果的格式转换成左上角右下角的格式。
# prediction [batch_size, num_anchors, 85]
box_corner = prediction.new(prediction.shape)
box_corner[:, :, 0] = prediction[:, :, 0] - prediction[:, :, 2] / 2
box_corner[:, :, 1] = prediction[:, :, 1] - prediction[:, :, 3] / 2
box_corner[:, :, 2] = prediction[:, :, 0] + prediction[:, :, 2] / 2
box_corner[:, :, 3] = prediction[:, :, 1] + prediction[:, :, 3] / 2
prediction[:, :, :4] = box_corner[:, :, :4]
output = [None for _ in range(len(prediction))]
for i, image_pred in enumerate(prediction):##对每张图片进行一个循环,但一般只有一张图片,循环只进行一次
#对置信度进行筛选
#先取出种类预测部分max
class_conf, class_pred = torch.max(image_pred[:, 5:5 + num_classes], 1, keepdim=True)
#image_pred[:, 4] 先验框内部是否包含物体的置信度;class_conf 种类置信度。获得总的置信度,与门限比较
conf_mask = (image_pred[:, 4] * class_conf[:, 0] >= conf_thres).squeeze()
image_pred = image_pred[conf_mask]
class_conf = class_conf[conf_mask]
class_pred = class_pred[conf_mask]
if not image_pred.size(0):
continue
#堆叠
detections = torch.cat((image_pred[:, :5], class_conf.float(), class_pred.float()), 1)
#进行种类的筛选,表示出预测结果中包含了种类的情况,后续循环减少数据量
unique_labels = detections[:, -1].cpu().unique()
if prediction.is_cuda:
unique_labels = unique_labels.cuda()
detections = detections.cuda()
#获得某一类得分筛选后全部的预测结果
for c in unique_labels:
#对最后一维度内容进行判断,预测框种类进行判断
detections_class = detections[detections[:, -1] == c]
keep = nms(
detections_class[:, :4],
detections_class[:, 4] * detections_class[:, 5],
nms_thres
)##先取最大的,再取并集,取舍
max_detections = detections_class[keep]
# # 按照存在物体的置信度排序
# _, conf_sort_index = torch.sort(detections_class[:, 4]*detections_class[:, 5], descending=True)
# detections_class = detections_class[conf_sort_index]
# # 进行非极大抑制
# max_detections = []
# while detections_class.size(0):
# # 取出这一类置信度最高的,一步一步往下判断,判断重合程度是否大于nms_thres,如果是则去除掉
# max_detections.append(detections_class[0].unsqueeze(0))
# if len(detections_class) == 1:
# break
# ious = bbox_iou(max_detections[-1], detections_class[1:])
# detections_class = detections_class[1:][ious < nms_thres]
# # 堆叠
# max_detections = torch.cat(max_detections).data
# Add max detections to outputs
output[i] = max_detections if output[i] is None else torch.cat((output[i], max_detections))
if output[i] is not None:
output[i] = output[i].cpu().numpy()
box_xy, box_wh = (output[i][:, 0:2] + output[i][:, 2:4]) / 2, output[i][:, 2:4] - output[i][:, 0:2]
output[i][:, :4] = self.yolo_correct_boxes(box_xy, box_wh, input_shape, image_shape, letterbox_image)
return output
本文参考一位大佬文章,一直在跟着这个老师学,此文只用作自己记录的笔记。
大佬原文链接:睿智的目标检测31——非极大抑制NMS与Soft-NMS_非极大抑制的作用_Bubbliiiing的博客-CSDN博客
对应B站视频链接:【深度学习小技巧-非极大抑制NMS及SOFT-NMS的实现(Bubbliiiing 深度学习 教程)-哔哩哔哩】 https://b23.tv/eQ4phVm睿智的目标检测31——非极大抑制NMS与Soft-NMS_非极大抑制的作用_Bubbliiiing的博客-CSDN博客