【AI】Pytorch实现NMS

下图是预测出来的box,左边三个框被预测为类别1,右边两个框被预测为类别2,其中:

坐标格式为(x1,y1,x2,y2)

左边的2个框:

黑色记为1,坐标为(0,0,3,3)

棕色记为2,坐标为(0,0,3.5,2.5)

橙色记为3;坐标为(0,2,2,4)

右边的2个框:

橙色的记为4,坐标为(4,0,6,2)

浅黄色的记为5,坐标为(4.5,0,6.5,2)

 

 现在需要设定不同的阈值,对2个类别做NMS,得到如下结果:

 

 

首先计算IoU,参见:

【AI】Pytorch实现IoU_Dreamcatcher风的博客-CSDN博客

其次,做NMS:

def nms(
    pre_boxs,
    iou_thres,
    prob_thrsh,
    cor_format
):
    # pre_boxs=[[class, conf, x1, y1, x2, y2], # 每一行是一个box的信息
    #           [class, conf, x1, y1, x2, y2], 
    #           ...
    #          ]
    classes = torch.unique(pre_boxs[:,0])
    all_boxes = []
    for i in classes:
        boxes = [box for box in pre_boxs if box[1] > prob_thrsh and box[0] == i] # 对每一个类别都做NMS;先剔除置信度很低的box,例如背景
        boxes_after_nms = []
        boxes.sort(key = lambda x:x[1], reverse=True)
        while boxes:
            chosen_box = boxes.pop(0)
            boxes = [
                box
                for box in boxes
                if box[0] != chosen_box[0]
                or cal_iou(chosen_box[2:],box[2:],cor_format='corner') < iou_thres
            ]
            boxes_after_nms.append(chosen_box)
        all_boxes.append(boxes_after_nms) # 每个元素对应一个类别做NMS后的box
    
    return all_boxes

测试:

# 输入上面5个框的坐标
boxes = torch.tensor([
    [1,0.9,0,0,3,3],
    [1,0.6,0,0,3.5,2.5],
    [1,0.3,0,2,2,4],
    [2,0.5,4,0,6,2],
    [2,0.5,4.5,0,6.5,2]])
num_box = nms(boxes, iou_thres=0.5, prob_thrsh=0.1, cor_format='corner')

import cv2

img = cv2.imread('D:/1.jpg')
r = 0
g = 0
b = 0
for box in boxes:
    print(r,g,b)
    cv2.rectangle(img,(int(50*box[2]),int(50*box[3])),(int(50*box[4]),int(50*box[5])),(r, g, b),2)
    r += 1
    g += 60 
    b += 120

cv2.imwrite('D:/raw.jpg',img)

img1 = cv2.imread('D:/1.jpg')
r = 0
g = 0
b = 0
for i in range(len(num_box)):
    for j in range(len(num_box[i])):
        cv2.rectangle(img1,(int(50*num_box[i][j][2]),int(50*num_box[i][j][3])),(int(50*num_box[i][j][4]),int(50*num_box[i][j][5])),(r, g, b),2)
        r += 1
        g += 60
        b += 190
cv2.putText(img1,'IoU thresh = 0.5',(150,200), cv2.FONT_HERSHEY_SIMPLEX, 0.75, (0, 0, 0), 2)
cv2.imwrite('D:/nms0.5.jpg',img1)

 

  • 2
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值