import numpy as np
import torch
def box_iou(box1, box2):
    """Compute the pairwise intersection-over-union of two sets of boxes.

    Boxes are in corner form ``(x1, y1, x2, y2)``: the first two coordinates
    are one corner and the last two the opposite corner. This assumes
    x1 <= x2 and y1 <= y2; TODO confirm callers guarantee it.

    Args:
        box1: tensor of shape [N, 4].
        box2: tensor of shape [M, 4].

    Returns:
        Tensor of shape [N, M] whose entry (i, j) is IoU(box1[i], box2[j]).
        A pair where both boxes have zero area gives 0/0 = nan.
    """
    # Broadcast to [N, 1, 2] vs [1, M, 2]: a* = top-left corners, b* = bottom-right corners.
    (a1, a2), (b1, b2) = box1.unsqueeze(1).chunk(2, 2), box2.unsqueeze(0).chunk(2, 2)
    # Per-axis overlap extent. It is negative when the boxes are disjoint along that
    # axis, so clamp to 0 (not abs) so that non-overlapping boxes have zero intersection.
    inter = (torch.min(a2, b2) - torch.max(a1, b1)).clamp(min=0).prod(2)
    # union = area1 + area2 - intersection
    return inter / ((a2 - a1).prod(2) + (b2 - b1).prod(2) - inter)
def nms_own(box, scores, iou_thresh):
    """Greedy non-maximum suppression.

    Repeatedly keeps the highest-scoring remaining box and discards every
    remaining box whose IoU with it is not below ``iou_thresh``.

    Args:
        box: tensor of shape [N, 4] in the corner format expected by ``box_iou``.
        scores: tensor of shape [N], one score per box.
        iou_thresh: a candidate survives only if its IoU with the kept box
            is strictly less than this value.

    Returns:
        List of 0-d index tensors into ``box`` for the kept boxes, in
        descending score order. Empty if there are no boxes.
    """
    score_index = scores.argsort(descending=True)
    keep = []
    # Loop while ANY candidate remains. Testing `> 1` here would drop the last survivor.
    while score_index.numel() > 0:
        best = score_index[0]
        keep.append(best)
        rest = score_index[1:]
        if rest.numel() == 0:
            break
        # Take row 0 instead of squeeze(), so a single remaining box still yields a
        # 1-D IoU vector (squeeze() would collapse it to 0-d and break the mask).
        iou = box_iou(box[best].unsqueeze(0), box[rest])[0]
        score_index = rest[iou < iou_thresh]
    return keep
# Demo: five overlapping boxes in (x1, y1, x2, y2) corner form, one score per box,
# run through nms_own with an IoU threshold of 0.8. Prints the kept indices.
boxes = torch.tensor(
    [
        [10, 20, 50, 80],
        [15, 30, 55, 75],
        [25, 35, 65, 70],
        [10, 40, 50, 90],
        [20, 25, 60, 85],
    ],
    dtype=torch.float,
)
score = torch.tensor([0.8, 0.7, 0.6, 0.9, 0.5])
result = nms_own(boxes, score, iou_thresh=0.8)
print(result)
# NMS code (NMS代码)
# Latest recommended article published 2024-07-28 15:46:11 (最新推荐文章于 2024-07-28 15:46:11 发布)