# nms code
import torch
import numpy as np
def iou(box, boxes):
#先计算各自框面积
box_area = (box[4] - box[2]) * (box[3] - box[1])
boxes_area = (boxes[:,4] - boxes[:,2]) * (boxes[:,3] - boxes[:,1])
xx1 = torch.maximum(box[1], boxes[:,1])
yy1 = torch.maximum(box[2], boxes[:,2])
xx2 = torch.minimum(box[3], boxes[:,3])
yy2 = torch.minimum(box[4], boxes[:,4])
#判断是否有交集
w, h = torch.maximum(torch.tensor([0]), xx2 - xx1), torch.maximum(torch.tensor([0]), yy2 - yy1)
over_area = w * h
return over_area / (box_area + boxes_area - over_area)
def nms(boxes, thresh=0.35):
#根据boxes预测的置信度做降序排列
new_boxes = boxes[boxes[:, 0].argsort(descending=True)]
print('new_boxes', new_boxes)
# 取出置信度最大的框
keep_boxes = [] #定义保留最大框数组
while len(new_boxes) > 0:
max_box = new_boxes[0]
keep_boxes.append(max_box)
#存下除去 置信度最高框的 其他框
if len(new_boxes) > 1:
other_boxes = new_boxes[1:]
#将最大框与其他框做IOU判断,小于阈值保留,证明框的不是同一个物体
new_boxes = other_boxes[torch.where(iou(max_box, other_boxes) < thresh)]
else:
break
return torch.stack(keep_boxes)
if __name__ == "__main__":
boxes = torch.tensor([[0.9, 1,1,10,10], [0.4, 1,2,10,9], [0.8, 3,2,10,12]]) #[x1,y1,x2,y2]
print('nms(boxes)', nms(boxes))
手撕nms代码
最新推荐文章于 2024-07-17 09:58:36 发布