Non-Maximum Suppression: A Code Walkthrough (utils.py)
Code repository on GitHub: eriklindernoren/PyTorch-YOLOv3 (Minimal PyTorch implementation of YOLOv3)
1. The non-maximum suppression function
# reference: https://github.com/eriklindernoren/PyTorch-YOLOv3/blob/f917503ffe4a21d2b1148d8cb13b89b834517d76/utils/utils.py
import torch

# xywh2xyxy and bbox_iou are helper functions defined in the same utils.py
def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.4):
    """
    Removes detections whose object confidence is below conf_thres, then
    applies non-maximum suppression to the remaining predictions.
    Returns, per image, detections with columns:
        (x1, y1, x2, y2, object_conf, cls_conf, cls_idx)
    """
    # Convert (center x, center y, width, height) to (x1, y1, x2, y2)
    prediction[..., :4] = xywh2xyxy(prediction[..., :4])
    output = [None for _ in range(len(prediction))]  # one entry per image in the batch
    # prediction: [batch_size x 10647 x 85] (for a 416x416 input and 80 classes)
    for image_i, image_pred in enumerate(prediction):
        # Keep only the rows whose object confidence is at least conf_thres
        image_pred = image_pred[image_pred[:, 4] >= conf_thres]
        # Skip this image if no row passed the threshold
        if not image_pred.size(0):
            continue
        # image_pred[:, 5:].max(1)[0] is the largest class probability in each row;
        # multiply it by the object confidence to get the score
        score = image_pred[:, 4] * image_pred[:, 5:].max(1)[0]
        # Sort image_pred by score in descending order
        image_pred = image_pred[(-score).argsort()]
        # Largest class probability and its class index for each row;
        # keepdim=True keeps them as (N, 1) columns so they can be concatenated
        class_confs, class_preds = image_pred[:, 5:].max(1, keepdim=True)
        # Build detections: [x1, y1, x2, y2, object_conf, cls_conf, cls_idx]
        detections = torch.cat((image_pred[:, :5], class_confs.float(), class_preds.float()), 1)
        # Perform non-maximum suppression
        keep_boxes = []
        while detections.size(0):
            # IoU between the highest-scoring box (the first row) and every box
            large_overlap = bbox_iou(detections[0, :4].unsqueeze(0), detections[:, :4]) > nms_thres
            # Rows whose class index equals that of the first row
            label_match = detections[0, -1] == detections[:, -1]
            # Rows that overlap the first box above the threshold and share its class
            invalid = large_overlap & label_match
            # Use the object confidence of those rows as weights
            weights = detections[invalid, 4:5]
            # Merge the overlapping boxes by weighted average and store the result in the first row
            detections[0, :4] = (weights * detections[invalid, :4]).sum(0) / weights.sum()
            keep_boxes += [detections[0]]
            detections = detections[~invalid]  # continue with the remaining predictions
        if keep_boxes:
            output[image_i] = torch.stack(keep_boxes)  # NMS result for this image
    return output
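The function above relies on two helpers, xywh2xyxy and bbox_iou, which are not shown in this post. The following is a minimal sketch of what such helpers might look like, written from how they are called above rather than copied from the repository. The eps parameter and the exact arithmetic are assumptions, and the repository's versions may differ in details.

```python
import torch

def xywh2xyxy(x):
    """Convert (center x, center y, width, height) boxes to (x1, y1, x2, y2)."""
    y = x.new_empty(x.shape)
    y[..., 0] = x[..., 0] - x[..., 2] / 2  # left edge
    y[..., 1] = x[..., 1] - x[..., 3] / 2  # top edge
    y[..., 2] = x[..., 0] + x[..., 2] / 2  # right edge
    y[..., 3] = x[..., 1] + x[..., 3] / 2  # bottom edge
    return y

def bbox_iou(box1, box2, eps=1e-16):
    """IoU of boxes in (x1, y1, x2, y2) form; a (1, 4) box1 broadcasts against (N, 4) box2."""
    # Corners of the intersection rectangle
    ix1 = torch.max(box1[:, 0], box2[:, 0])
    iy1 = torch.max(box1[:, 1], box2[:, 1])
    ix2 = torch.min(box1[:, 2], box2[:, 2])
    iy2 = torch.min(box1[:, 3], box2[:, 3])
    # Clamp at zero so disjoint boxes get an empty intersection
    inter = (ix2 - ix1).clamp(min=0) * (iy2 - iy1).clamp(min=0)
    area1 = (box1[:, 2] - box1[:, 0]) * (box1[:, 3] - box1[:, 1])
    area2 = (box2[:, 2] - box2[:, 0]) * (box2[:, 3] - box2[:, 1])
    # eps (an assumption of this sketch) guards against division by zero
    return inter / (area1 + area2 - inter + eps)
```

With these two definitions in scope, non_max_suppression can be called on a [batch_size x N x (5 + num_classes)] tensor. Note that it converts the box coordinates of the input tensor in place.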
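2. Filtering by object confidence
(1) First convert the predicted values [cx,cy,w,h] into top-left and bottom-right corner coordinates [x1,y1,x2,y2], then keep only the prediction rows whose object confidence is at least the threshold. Code:
# [cx,cy,w,h] -> [x1,y1,x2,y2]
prediction[...,:4] = xywh2xyxy(prediction[...,:4])
# Filter by object confidence
img_pred = img_pred[img_pred[:,4]>=conf_thres]
(2) After filtering by the confidence threshold, the data looks like this: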
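To make the filtering step concrete, here is a minimal sketch on made-up values (these are not the author's output). It assumes a single image with two classes, so each row has 7 columns instead of 85, and that the boxes have already been converted to corner form.

```python
import torch

conf_thres = 0.5
# Hypothetical rows: [x1, y1, x2, y2, object_conf, class0_prob, class1_prob]
img_pred = torch.tensor([
    [0., 0., 10., 10., 0.9, 0.8, 0.1],
    [1., 1., 11., 11., 0.3, 0.7, 0.2],    # object_conf below the threshold
    [20., 20., 30., 30., 0.6, 0.2, 0.9],
])

mask = img_pred[:, 4] >= conf_thres   # boolean mask with one entry per row
filtered = img_pred[mask]             # boolean indexing keeps only the True rows
print(mask)
print(filtered.shape)
```

Only the object confidence in column 4 is used here; the class probabilities come into play in the next step.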
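3. Descending sort and regrouping the predictions
(1) The sort key is the product of each row's largest class probability and its object confidence. Code:
# img_pred[:, 5:].max(1)[0] is the largest class probability in each row; multiply it by the object confidence
score = img_pred[:,4] * img_pred[:,5:].max(1)[0]
# Sort by score in descending order
img_pred = img_pred[(-score).argsort()]
(2) Regrouping the predictions: append the class probability and class index, cls_conf and cls_idx, after [x1,y1,x2,y2,object_conf]. Only the largest class probability of each row is kept. Code:
# Largest class probability and its class index for each row; keepdim=True keeps them as (N, 1) columns
cls_confs, cls_idx = img_pred[:,5:].max(1,keepdim=True)
# Build detections: [x1,y1,x2,y2,conf,cls_conf,cls_idx]
detections = torch.cat((img_pred[:, :5], cls_confs.float(), cls_idx.float()), 1)
(3) Execution result: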
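The sorting and regrouping steps can be traced on a small example. The values below are made up for illustration (two classes, three rows that are assumed to have passed the confidence filter) and are not the author's output.

```python
import torch

# Hypothetical rows: [x1, y1, x2, y2, object_conf, class0_prob, class1_prob]
img_pred = torch.tensor([
    [0., 0., 10., 10., 0.6, 0.9, 0.1],
    [1., 1., 11., 11., 0.9, 0.8, 0.2],
    [20., 20., 30., 30., 0.7, 0.3, 0.5],
])

# score = object confidence * largest class probability
score = img_pred[:, 4] * img_pred[:, 5:].max(1)[0]
# argsort sorts ascending, so negating the score yields descending order
order = (-score).argsort()
img_pred = img_pred[order]

# keepdim=True returns (N, 1) columns, which is what torch.cat along dim 1 needs
cls_confs, cls_idx = img_pred[:, 5:].max(1, keepdim=True)
detections = torch.cat((img_pred[:, :5], cls_confs.float(), cls_idx.float()), 1)
print(order)
print(detections)
```

Each row of detections now has 7 columns regardless of the number of classes, because the per-class probabilities have been reduced to one probability and one index. The class index is stored as a float so that it fits in the same tensor as the coordinates.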
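4. Non-maximum suppression: selecting rows by condition
(1) Condition: a row's IoU with the first row is greater than the threshold, and its class index equals that of the first row. Code:
# Rows whose IoU with the first row exceeds the threshold and whose class index equals that of the first row
large_overlap = bbox_iou(detections[0,:4].unsqueeze(0), detections[:,:4])>nms_thres
label_match = detections[0, -1] == detections[:, -1]
invalid = large_overlap & label_match
(2) Remove the rows that satisfy invalid. The first row always satisfies it (its IoU with itself is 1 and its class matches itself), so it is removed along with the boxes it suppresses.
detections = detections[~invalid]
(3) Execution result: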
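The following sketch shows how the two conditions combine on made-up detections (not the author's output). The function iou_one_to_many is a hypothetical stand-in for bbox_iou, included only so the example runs on its own.

```python
import torch

def iou_one_to_many(box, boxes):
    """Hypothetical stand-in for bbox_iou: IoU of one (x1, y1, x2, y2) box against (N, 4) boxes."""
    x1 = torch.max(box[0], boxes[:, 0])
    y1 = torch.max(box[1], boxes[:, 1])
    x2 = torch.min(box[2], boxes[:, 2])
    y2 = torch.min(box[3], boxes[:, 3])
    inter = (x2 - x1).clamp(min=0) * (y2 - y1).clamp(min=0)
    area = (box[2] - box[0]) * (box[3] - box[1])
    areas = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
    return inter / (area + areas - inter)

nms_thres = 0.4
# Hypothetical rows, already sorted by score: [x1, y1, x2, y2, object_conf, cls_conf, cls_idx]
detections = torch.tensor([
    [0., 0., 10., 10., 0.9, 0.9, 0.],    # highest-scoring box, class 0
    [1., 0., 11., 10., 0.8, 0.9, 0.],    # heavy overlap, same class
    [1., 0., 11., 10., 0.7, 0.9, 1.],    # heavy overlap, different class
    [20., 20., 30., 30., 0.6, 0.9, 0.],  # same class, no overlap
])

large_overlap = iou_one_to_many(detections[0, :4], detections[:, :4]) > nms_thres
label_match = detections[0, -1] == detections[:, -1]
invalid = large_overlap & label_match
remaining = detections[~invalid]
print(invalid)
print(remaining)
```

Only rows that meet both conditions are suppressed, so a box of another class survives even when it overlaps the first box heavily, and a box of the same class survives when it is far away.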
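5. Non-maximum suppression: confidence-weighted average of the boxes
(1) Average the boxes, weighted by object confidence. In the full function this step runs before the invalid rows are removed in step 4 (2), so the rows selected by invalid are still present.
# For the rows that meet the condition above, use the object confidence as weights
weights = detections[invalid, 4:5]
# Merge by weighted average and store the result in the first row
detections[0, :4] = (weights * detections[invalid, :4]).sum(0) / weights.sum()
# Keep the merged detection
keep_boxes += [detections[0]]
(2) Execution result:
(3) The final output: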
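As a rough illustration of sections 4 and 5 working together, the sketch below runs the suppression loop on made-up detections (not the author's output). The names nms_merge and iou_one_to_many are hypothetical; iou_one_to_many stands in for bbox_iou so that the example is self-contained.

```python
import torch

def iou_one_to_many(box, boxes):
    """Hypothetical stand-in for bbox_iou: IoU of one (x1, y1, x2, y2) box against (N, 4) boxes."""
    x1 = torch.max(box[0], boxes[:, 0])
    y1 = torch.max(box[1], boxes[:, 1])
    x2 = torch.min(box[2], boxes[:, 2])
    y2 = torch.min(box[3], boxes[:, 3])
    inter = (x2 - x1).clamp(min=0) * (y2 - y1).clamp(min=0)
    area = (box[2] - box[0]) * (box[3] - box[1])
    areas = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
    return inter / (area + areas - inter)

def nms_merge(detections, nms_thres=0.4):
    """Sketch of the inner loop; detections is (N, 7) and sorted by descending score."""
    detections = detections.clone()  # leave the caller's tensor untouched
    keep_boxes = []
    while detections.size(0):
        large_overlap = iou_one_to_many(detections[0, :4], detections[:, :4]) > nms_thres
        label_match = detections[0, -1] == detections[:, -1]
        invalid = large_overlap & label_match
        # Object confidences as an (M, 1) column so they broadcast over the 4 coordinates
        weights = detections[invalid, 4:5]
        detections[0, :4] = (weights * detections[invalid, :4]).sum(0) / weights.sum()
        keep_boxes.append(detections[0])
        detections = detections[~invalid]
    return torch.stack(keep_boxes) if keep_boxes else None

# Hypothetical rows: [x1, y1, x2, y2, object_conf, cls_conf, cls_idx]
detections = torch.tensor([
    [0., 0., 10., 10., 0.9, 0.8, 0.],
    [20., 20., 30., 30., 0.7, 0.9, 1.],
    [2., 0., 12., 10., 0.6, 0.7, 0.],   # overlaps the first box, same class
])
result = nms_merge(detections)
print(result)
```

The first and third rows are merged into one box whose coordinates lie between the two, pulled toward the more confident one, while its object_conf, cls_conf and cls_idx stay those of the first row. In the full function, output holds one such (n, 7) tensor per image, or None for an image where no row passed the confidence filter.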