最详细NMS代码解释--python--pytorch

#NMS

'''
image_pred_ 是一个图片上所有bbox组成的二维varibale
(all_bboxs_num, 4+1+1+1) 4:左上右下坐标,1:检测置信度,1:最大分类概率,1:最大分类概率的类别序号--代表哪一个类
对每一个类进行NMS
'''
for cls in img_classes:
'''
得到该类的所有检测框,剔除非该类的检测框
不为该类的bbox设为0 为该类的设为1
用非0 行索引提取非0 bbox 剔除0 bbox
用nonzero()[] 提取非0 bbox的行索引

'''
    cls_mask = image_pred_ * (image_pred_[:,-1] == cls).float().unsqueeze(1)
    class_mask_ind = torch.nonzero(cls_mask[:,-2]).squeeze()
    image_pred_class = image_pred_[class_mask_ind].view(-1,7)

'''
2排序--用置信度进行降序排序 sort()[1] 返回行索引,用行索引重排pred
idx 该类的bbox的个数
'''
    conf_sort_index = torch.sort(image_pred_class[:,4]), descending = True)[1]
    image_pred_class = image_pred_class[conf_sort_index]
    idx = imge_pred_class.size(0)
    for i in range(idx)

'''
3计算iou
计算pre[i] 与 pre[i+1 : ]的ious
剩余最后一个后就break跳出
'''
        try:
            ious = bbox_iou(image_pred_class[i], image_pred_class[i+1: ])
       except: ValueError:
           break
       except: IndexError:
           break

'''
根据iou剔除大于阈值的bbox
iou 大于阈值的设为0 ,小于阈值的设为1 ,
用nonzero()提取置信度非零行索引
用非零索引提取非零bbox
'''
        iou_mask = (ious < nms_conf).float().unsqueeze(1)
        image_pred_class[i+1: ] *= iou_mask
        non_zero_ind = torch.nonzero(image_pred_class[:,4]).squeeze()
        image_pred_class = image_pred_class[non_zero_ind].view(-1,7)

  #为该类别NMS挑选后的bbox 前面打上batch的序号

    batch_ind = image_pred_class.new(image_pred_class.size(0),1).fill_(ind)
    seq = batch_ind, image_pred_class
    #cat所有的结果--> output
    if not write:    #新Tensor值不能为空 初始化
        output = torch.cat(seq,1)
        write = True
    else:
        out = torch.cat(seq,1)        #在列维度上拼接--横向拼接
        output = torch.cat((output,out))  #在行维度上拼接--纵向拼接

#计算IOU

def bbox_iou(bbox1,bbox2):
    if coordinate == 'center':      #坐标为中心点坐标x,y,w,h
        bbox1_[:,0] = bbox[:,0] - bbox[:,2]/2
        bbox1_[:,1] = bbox[:,1] - bbox[:,3]/2
        bbox1_[:,2] = bbox[:,0] + bbox[:,2]/2
        bbox1_[:,1] = bbox[:,1] + bbox[:,3]/2

        bbox2_[:,0] = bbox[:,0] - bbox[:,2]/2
        bbox2_[:,1] = bbox[:,1] - bbox[:,3]/2
        bbox2_[:,2] = bbox[:,0] + bbox[:,2]/2
        bbox2_[:,1] = bbox[:,1] + bbox[:,3]/2

        b1_x1, b1_y1, b1_x2, b1_y2 = bbox1_[:,0], bbox1_[:,1], bbox1_[:,2], bbox1_[:,3]
        b2_x1, b2_y1, b2_x2, b2_y2 = bbox2_[:,0], bbox2_[:,1], bbox2_[:,2], bbox1_[:,3]



    if coordinate == 'angle':  #坐标为左上右下定点坐标x1,y1,x2,y2
        b1_x1, b1_y1, b1_x2, b1_y2 = bbox1[:,0], bbox1[:,1], bbox1[:,2], bbox1[:,3]
        b2_x1, b2_y1, b2_x2, b2_y2 = bbox2[:,0], bbox2[:,1], bbox2[:,2], bbox1[:,3]

#计算交集区域面积

    inter_x1 = torch.max(b1_x1,b2_x1)
    inter_y1 = torch.max(b1_y1,b2_y1)
    inter_x2 = torch.max(b1_x2,b2_x2)
    inter_y2 = torch.max(b1_y2,b2_y2)
    inter_arae = torch.clamp(inter_x2 - inter_x1+1,min=0)*torch.clamp(inter_y2 - intery1 +1 ,min =0)

#计算并集区域面积
    b1_area = (b1_x2 - b1_x1 + 1) * (b1_y2 - b1_y1 + 1)
    b2_area = (b2_x2 - b2_x1 + 1) * (b2_y2 - b2_y1 + 1)
    all_area = b1_area + b2_area - inter_area
    iou = inter_area/all_area

return iou

 

  • 3
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值