Python实现NMS非极大值抑制算法
NMS算法的具体过程这里不再赘述，网上已有很多相关讲解，下面直接给出代码。
bbox_iou 用于计算边框之间的 IoU（也支持 GIoU/DIoU/CIoU），代码直接取自 YOLOv5。
nms 和 nms_ 都实现了 NMS 算法，返回值均为经过 NMS 过滤后保留的边框索引。两者的区别在于：nms 以类似选择排序的方式原地交换边框、逐个处理；nms_ 则每一轮直接在全部分数中选取最大值。
import torch
import torchvision
from torch import Tensor
import math
def bbox_iou(box1, box2, xywh=True, GIoU=False, DIoU=False, CIoU=False, eps=1e-7):
    """Return the IoU (or GIoU / DIoU / CIoU) between ``box1`` and ``box2``.

    Both inputs are split into 4 coordinates along their last dimension, so the
    result has the broadcast shape of the two inputs with a trailing dimension of
    size 1 (e.g. box1 of shape [4] and box2 of shape [M, 4] give [M, 1]).

    :param box1: box(es) as (cx, cy, w, h) when ``xywh`` is True, else (x1, y1, x2, y2)
    :param box2: box(es) in the same format as ``box1``
    :param xywh: whether the inputs are center/size encoded
    :param GIoU: return Generalized IoU
    :param DIoU: return Distance IoU
    :param CIoU: return Complete IoU (takes precedence over DIoU, which takes precedence over GIoU)
    :param eps: small constant that keeps the divisions finite
    :return: overlap tensor as described above
    """
    if xywh:
        cx1, cy1, w1, h1 = box1.chunk(4, -1)
        cx2, cy2, w2, h2 = box2.chunk(4, -1)
        half_w1, half_h1, half_w2, half_h2 = w1 / 2, h1 / 2, w2 / 2, h2 / 2
        left1, right1 = cx1 - half_w1, cx1 + half_w1
        top1, bottom1 = cy1 - half_h1, cy1 + half_h1
        left2, right2 = cx2 - half_w2, cx2 + half_w2
        top2, bottom2 = cy2 - half_h2, cy2 + half_h2
    else:
        left1, top1, right1, bottom1 = box1.chunk(4, -1)
        left2, top2, right2, bottom2 = box2.chunk(4, -1)
        # eps on the heights keeps the CIoU aspect-ratio term (w / h) finite.
        w1 = right1 - left1
        h1 = bottom1 - top1 + eps
        w2 = right2 - left2
        h2 = bottom2 - top2 + eps

    overlap_w = (torch.min(right1, right2) - torch.max(left1, left2)).clamp(0)
    overlap_h = (torch.min(bottom1, bottom2) - torch.max(top1, top2)).clamp(0)
    inter = overlap_w * overlap_h
    union = w1 * h1 + w2 * h2 - inter + eps
    iou = inter / union

    if not (CIoU or DIoU or GIoU):
        return iou

    # Size of the smallest axis-aligned box enclosing both inputs.
    hull_w = torch.max(right1, right2) - torch.min(left1, left2)
    hull_h = torch.max(bottom1, bottom2) - torch.min(top1, top2)

    if not (CIoU or DIoU):
        # GIoU: penalize the part of the enclosing box not covered by the union.
        hull_area = hull_w * hull_h + eps
        return iou - (hull_area - union) / hull_area

    # Squared diagonal of the enclosing box and squared distance between the centers.
    diag_sq = hull_w ** 2 + hull_h ** 2 + eps
    center_dist_sq = ((left2 + right2 - left1 - right1) ** 2 + (top2 + bottom2 - top1 - bottom1) ** 2) / 4
    if not CIoU:
        return iou - center_dist_sq / diag_sq

    # CIoU: additionally penalize aspect-ratio mismatch; alpha is treated as a constant.
    v = (4 / math.pi ** 2) * torch.pow(torch.atan(w2 / h2) - torch.atan(w1 / h1), 2)
    with torch.no_grad():
        alpha = v / (v - iou + (1 + eps))
    return iou - (center_dist_sq / diag_sq + v * alpha)
def nms(
    boxes: Tensor, scores: Tensor,
    score_threshold: float = 0.25,
    iou_threshold: float = 0.45,
    giou: bool = False, diou: bool = False, ciou: bool = False,
    soft_nms: bool = False,
    sigma: float = 0.5
):
    """Run (soft-)NMS with an in-place selection sort and return the kept box indices.

    Boxes are processed in descending score order: at step ``i`` the highest
    remaining score is swapped into position ``i`` and every later box has its
    score decayed according to its overlap with box ``i``.

    :param boxes: boxes [N, 4] as [x1, y1, x2, y2] (only the first 4 columns are used)
    :param scores: confidence scores [N]
    :param score_threshold: boxes whose final score is below this value are dropped;
        a box below it never suppresses other boxes
    :param iou_threshold: only boxes whose IoU with the current box is >= this value
        have their score updated
    :param giou: use GIoU as the overlap measure (GIoU-NMS / GIoU-soft-NMS)
    :param diou: use DIoU as the overlap measure (DIoU-NMS / DIoU-soft-NMS)
    :param ciou: use CIoU as the overlap measure (CIoU-NMS / CIoU-soft-NMS)
    :param soft_nms: if True, multiply overlapping scores by exp(-iou^2 / sigma)
        instead of zeroing them
    :param sigma: soft-NMS hyper-parameter
    :return: int32 tensor (on ``boxes.device``) with the original indices of the
        kept boxes, ordered by descending (decayed) score
    """
    num_boxes = boxes.shape[0]
    device = boxes.device
    boxes = boxes[:, :4].clone()
    scores = scores.clone().to(device)
    # Original position of every box, tracked as its own integer tensor. Storing it as
    # an extra column of the float `boxes` tensor would round large indices for
    # fp16 / bf16 boxes.
    order = torch.arange(num_boxes, device=device)
    for i in range(num_boxes - 1):
        # Selection step: bring the highest remaining score to position i. Swap only on a
        # strict improvement, so tied boxes keep their current order.
        j = i + int(torch.argmax(scores[i:]))
        if scores[j] > scores[i]:
            boxes[[i, j]] = boxes[[j, i]]
            scores[[i, j]] = scores[[j, i]]
            order[[i, j]] = order[[j, i]]
        if scores[i] < score_threshold:
            # The best remaining score is already below the threshold. Weights are at most 1
            # (assuming sigma > 0 and non-negative scores), so nothing left can be kept.
            break
        iou = bbox_iou(boxes[i], boxes[i + 1:], xywh=False, GIoU=giou, DIoU=diou, CIoU=ciou)[:, 0]
        if soft_nms:
            # Only overlaps at or above iou_threshold contribute to the Gaussian decay.
            iou = torch.where(iou < iou_threshold, torch.zeros_like(iou), iou)
            weights = torch.exp(-(iou * iou) / sigma)
        else:
            # Hard NMS: zero the score of every box overlapping by >= iou_threshold.
            weights = torch.where(iou < iou_threshold, torch.ones_like(iou), torch.zeros_like(iou))
        scores[i + 1:] = weights * scores[i + 1:]
    # Same comparison as the processing check above. A box that was allowed to
    # suppress its neighbours (score >= threshold) is also returned.
    return order[scores >= score_threshold].int()
def nms_(
    boxes: Tensor, scores: Tensor,
    score_threshold: float = 0.25,
    iou_threshold: float = 0.45,
    giou: bool = False, diou: bool = False, ciou: bool = False,
    soft_nms: bool = False,
    sigma: float = 0.5
):
    """Run (soft-)NMS by repeatedly picking the highest-scoring box still in play.

    Each round selects the candidate with the largest current score, records its
    index, and then either drops the candidates that overlap it (hard NMS) or
    decays their scores (soft-NMS).

    :param boxes: boxes [N, 4] as [x1, y1, x2, y2] (only the first 4 columns are used)
    :param scores: confidence scores [N]
    :param score_threshold: selection stops once the best remaining score is below this value
    :param iou_threshold: only boxes whose IoU with the selected box is >= this value
        have their score updated
    :param giou: use GIoU as the overlap measure (GIoU-NMS / GIoU-soft-NMS)
    :param diou: use DIoU as the overlap measure (DIoU-NMS / DIoU-soft-NMS)
    :param ciou: use CIoU as the overlap measure (CIoU-NMS / CIoU-soft-NMS)
    :param soft_nms: if True, multiply overlapping scores by exp(-iou^2 / sigma)
        instead of dropping the boxes
    :param sigma: soft-NMS hyper-parameter
    :return: int32 CPU tensor of the kept box indices, in selection order
        (descending, possibly decayed, score)
    """
    num_boxes = boxes.shape[0]
    device = boxes.device
    scores = scores.clone().to(device)
    # Boxes that are still candidates. The selected box always leaves the pool, so the
    # loop terminates even when a box has IoU 0 with itself (zero area), when
    # iou_threshold >= 1, or when score_threshold <= 0. In those cases merely zeroing
    # scores would select the same box forever.
    active = torch.ones(num_boxes, dtype=torch.bool, device=device)
    keep = []
    while bool(active.any()):
        candidates = scores.masked_fill(~active, float('-inf'))
        max_score, max_index = torch.max(candidates, dim=0)
        if max_score < score_threshold:
            break
        keep.append(int(max_index))
        active[max_index] = False
        iou = bbox_iou(boxes[max_index, :4], boxes[:, :4], xywh=False, GIoU=giou, DIoU=diou, CIoU=ciou)[:, 0]
        if soft_nms:
            # Only overlaps at or above iou_threshold contribute to the Gaussian decay.
            iou = torch.where(iou < iou_threshold, torch.zeros_like(iou), iou)
            scores = scores * torch.exp(-(iou * iou) / sigma)
        else:
            # Hard NMS: boxes overlapping the selected one by >= iou_threshold are dropped.
            active &= iou < iou_threshold
    return torch.tensor(keep, dtype=torch.int32)
if __name__ == '__main__':
    # Two pairs of heavily overlapping boxes; each pair should collapse to one box.
    demo_boxes = torch.tensor(
        [
            [50, 50, 100, 100],
            [52, 54, 106, 98],
            [34, 32, 78, 38],
            [33, 32, 76, 38],
        ],
        dtype=torch.float,
    )
    demo_scores = torch.tensor([0.9, 0.8, 0.85, 0.86])

    # Reference result from torchvision first, then each hand-written variant
    # in hard-NMS mode followed by soft-NMS mode.
    print(torchvision.ops.nms(demo_boxes, demo_scores, 0.4))
    hard_options = dict(iou_threshold=0.4)
    soft_options = dict(iou_threshold=0.4, score_threshold=0.2, soft_nms=True)
    for nms_impl in (nms, nms_):
        for options in (hard_options, soft_options):
            print(nms_impl(demo_boxes, demo_scores, **options))