说多说少都是泪,被问到NMS如何实现的时候侃侃而谈,但是一旦从零手写就捉襟见肘,往后再学习的时候一定作到从零书写每一块代码。
import torch
# 如果是大框完全包住小框,则为over/Smin
# 1.先对置信度进行排序,2.取最大的置信度与后面的进行iou比较,筛选iou满足条件的(小于阈值的),
# 3.再从筛选后的框中选置信度最大的(前面哪个最大的已经保留到别处了,此时不再参选),与后面剩下的框进行iou比较.
# 以此类推,直到最后一部分只剩下一个框,结束nms!
def iou(box, boxes, isMin=False):
box_area = (box[2] - box[0]) * (box[3] - box[1])
boxes_areas = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
# 交集
xx1 = torch.max(box[0], boxes[:, 0])
yy1 = torch.max(box[1], boxes[:, 1])
xx2 = torch.min(box[2], boxes[:, 2])
yy2 = torch.min(box[3], boxes[:, 3])
w, h = torch.maximum(torch.Tensor([0]), xx2 - xx1), torch.maximum(torch.Tensor([0]), yy2 - yy1)
over_area = w * h
if isMin:
return over_area / torch.min(box_area, boxes_areas)
else:
return over_area / (box_area + boxes_areas - over_area)
def nms(boxes, thresh=0.3, isMin=False):
new_boxes = boxes[boxes[:, 0].argsort(descending=True)]
# argsort只能返回位置坐标[2,1,0],也就是原boxes根据置信度降序排序的原位置
# print(new_boxes)
keep_boxes = []
while len(new_boxes) > 0:
_box = new_boxes[0]
keep_boxes.append(_box)
if len(new_boxes) > 1:
_boxes = new_boxes[1:]
new_boxes = _boxes[iou(_box, _boxes, isMin) < thresh]
else:
break
return torch.stack(keep_boxes)
if __name__ == "__main__":
boxes = torch.tensor(([0.5, 1, 1, 10, 10], [0.5, 1, 2, 11, 11], [0.6, 1, 1, 5, 5], [0.9, 1, 1, 5, 5]))
# print(iou(box, boxes))
print(nms(boxes, 0.1))