非极大值抑制 non-max-suppresion

所谓非极大值抑制,就是把不是最大的值给剔除掉。在深度学习之目标检测中,当模型预测一个物体时,会有很多框,这个时候,我们一般会选取置信度最大的作为输出。同时,把置信度小的与这个置信度最大的 IOU 大于一定阈值的给剔除掉,IOU大于一定阈值,就是两者重合度比较大了。

IOU的计算可以参照:

tensor与list版本的iou比较_shenjianhua005的专栏-CSDN博客

非极大值的代码,我是参照:

Machine-Learning-Collection/ML/Pytorch/object_detection/YOLOv3 at master · aladdinpersson/Machine-Learning-Collection · GitHub

这里主要对这个代码讲解下:

import torch

import xml.etree.ElementTree as ET

def convert_annotation(in_file):
    tree=ET.parse(in_file)
    root = tree.getroot()
    size = root.find('size')
    w = int(size.find('width').text)
    h = int(size.find('height').text)

    bbox = []
    for obj in root.iter('object'):
        cls = obj.find('name').text
        xmlbox = obj.find('bndbox')
        bbox.append([cls, int(xmlbox.find('xmin').text), int(xmlbox.find('xmax').text), int(xmlbox.find('ymin').text), int(xmlbox.find('ymax').text)])

    return bbox


def iou(chosen, box_val):

    xmin = torch.max(chosen[0], box_val[0])
    ymin = torch.max(chosen[1], box_val[1])
    xmax = torch.min(chosen[2], box_val[2])
    ymax = torch.min(chosen[3], box_val[3])

    inter_area = (xmax-xmin) * (ymax-ymin)

    chosen_area = (chosen[2]-chosen[0]) * (chosen[3]-chosen[1])
    box_val_area = (box_val[2]-box_val[0]) * (box_val[3]-box_val[1])

    val = inter_area / (1e-16 + chosen_area + box_val_area - inter_area)
    return val


def non_max_suppression(pred, prob_threshold, iou_threshold):
    assert type(pred) == list

    non_suppression_after = []  # 这个用于存放需要显示出来的结果
    pred = [val for val in pred if val[1] > prob_threshold]   # 这个是把一些阈值比较低的给过滤掉,留下阈值高的进行IOU计算
    box = sorted(pred, key=lambda x:x[1], reverse=True) # 排序,这样就是从高到低
    while box:
        chose = box.pop(0) # 第一个值,这样box中就少一个

        # 这个里面进行了类别判断,只有相同类别才进行IOU判断,这里是把 IOU 重合度较小的留下来,然后依次与后面比较,总体而言,还是比较灵活。
        box = [box_val for box_val in box if chose[0] != box_val[0] or iou(torch.tensor(chose[2:]), torch.tensor(box_val[2:])) < iou_threshold]

        non_suppression_after.append(chose)

    return non_suppression_after


if __name__ == '__main__':
    # 第一个是类别,第二个是置信度,后面就是left-top, right-down
    pred = [['1', 0.8, 138, 518, 259, 525], ['1', 0.6, 119, 474, 316, 770]]

    prob_threshold = 0.5
    iou_threshold = 0.5
    nms_after = non_max_suppression(pred, prob_threshold, iou_threshold)
    if nms_after:
        print(nms_after)

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值