非极大抑制nms(non-maximum suppression)的pytorch实现

4 篇文章 0 订阅
4 篇文章 0 订阅

nms在目标检测算法中作为后处理,从众多感兴趣的区域中筛选出最优的结果。

如R-CNN系列的目标检测,数千个ROI(region of interes感兴趣区域)经过模型的计算后,仍输出了不少的Bounding Box(后简称bbox)。
实际上,图像中的目标寥寥无几,一个目标很可能输出多个相互重叠的bbox。nms就是从众多交叠的bbox中提取最可信的结果。

Non-maximum suppression,译为非极大值抑制,顾名思义,就是对Confidence(置信度)并非最大的bbox进行抑制(过滤),只留下Confidence最大的bbox,就是模型最可信的结果。

流程可归纳如下:

  • 先求各bbox之间的IoU,即下面jaccard函数;
  • 对bbox的score(置信度)降序排列,即下面sort(0, descending=True)
  • 按照排序后顺序依次取出bbox,若此bbox与res_idx中任何bbox的iou小于threshold(阈值)(证明此bbox是新的目标),则加入res_idx,反之丢弃;
  • 最终根据res_idx的索引,返回最终结果。

先上结果,会有更直观的理解:

if __name__ == '__main__':
	lists = []
	threshold = 0.8
	
	# x1, y1, x2, y2, score
	lists.append([1, 1, 3, 3, 0.95])
	lists.append([1, 1, 3, 4, 0.93])
	lists.append([1, 0.9, 3.6, 3, 0.98])
	lists.append([1, 0.9, 3.5, 3, 0.97])
	lists = torch.tensor(lists)

	res = NMS(lists, threshold)
	print(lists)
	# tensor([[1.0000, 1.0000, 3.0000, 3.0000, 0.9500],
    #     	  [1.0000, 1.0000, 3.0000, 4.0000, 0.9300],
    #     	  [1.0000, 0.9000, 3.6000, 3.0000, 0.9800],
    #     	  [1.0000, 0.9000, 3.5000, 3.0000, 0.9700]])
	print(res)
	# tensor([[1.0000, 0.9000, 3.6000, 3.0000, 0.9800],
    #     	  [1.0000, 1.0000, 3.0000, 3.0000, 0.9500],
    #     	  [1.0000, 1.0000, 3.0000, 4.0000, 0.9300]])

可见score为0.97的bbox被排除了。因为它比0.98的score更小,且框的位置很接近(iou超过threshold)。


nms的实现主要分3个部分

  1. nms函数
def NMS(lists, threshold):
	"""
	lists.shape = (n, 5),n是bbox的个数
	lists[0, :] =  x1, y1, x2, y2, score
	threshold: IoU的阈值
	return: Tensor.shape(m, 5),m是过滤后bbox的个数
	"""

	# overlaps.shape: [lists.shape[0], lists.shape[0]]
	# overlaps里放的是每个bbox之间的iou
	# jaccard的实现在后面
	# 假设lists里有4个bbox
	# tensor([[1.0000, 0.6667, 0.7326, 0.7619],
    #    	  [0.6667, 1.0000, 0.5362, 0.5517],
    #    	  [0.7326, 0.5362, 1.0000, 0.9615],
    #    	  [0.7619, 0.5517, 0.9615, 1.0000]])
	overlaps = jaccard(lists, lists)

	res_idxs = []
	# lists里的元素为x1, y1, x2, y2, score
	# 这里针对score做降序排列
	# 返回排列后的结果,以及索引。这里我们只需要索引
	_, idxs = lists[:, 4].sort(0, descending=True)
	for idx in idxs:
		tag = True
		for r_idx in res_idxs:
			# 后续的bbox如果与我们想要的bbox的iou大于阈值
			# 就丢弃
			if overlaps[idx, r_idx] > threshold:
				tag = False
				break
		if tag:
			res_idxs.append(idx.item())
	# 根据索引返回最终结果
	return lists[res_idxs]
  1. jaccard函数
def jaccard(box_a, box_b):
	# box_a: x1, y1, x2, y2
	# box_b: x1, y1, x2, y2
	# 在本应用场景中,box_a和box_b是同一对象
	# intersection返回box_a、box_b的相交大小,后面会讲
	# tensor([[4.0000, 4.0000, 4.0000, 4.0000],
    #         [4.0000, 6.0000, 4.0000, 4.0000],
    #         [4.0000, 4.0000, 5.4600, 5.2500],
    #         [4.0000, 4.0000, 5.2500, 5.2500]])
	inter = intersection(box_a, box_b)

	# area_a是box_a的面积
	# box_a[:, 2] - box_a[:, 0]是宽
	# box_a[:, 3] - box_a[:, 1]是高
	# area_a.shape: inter.shape
	area_a = ((box_a[:, 2] - box_a[:, 0]) *
			  (box_a[:, 3] - box_a[:, 1])).unsqueeze(1).expand_as(inter)
	# 同理
	area_b = ((box_b[:, 2] - box_b[:, 0]) *
			  (box_b[:, 3] - box_b[:, 1])).unsqueeze(0).expand_as(inter)
	# 并集
	union = area_a + area_b - inter
	return inter / union
  1. intersection函数
def intersection(box_a, box_b):
	A = box_a.size(0)
	B = box_b.size(0)
	# bottom right 右下角点
	# br_xy.shape: [A, B, (x2, y2)]
	br_xy = torch.min(box_a[:, 2:4].unsqueeze(1).expand((A, B, 2)),
					  box_b[:, 2:4].unsqueeze(0).expand((A, B, 2)))
	# top left 左上角点
	# tl_xy.shape: [A, B, (x1, y1)]
	tl_xy = torch.max(box_a[:, 0:2].unsqueeze(1).expand((A, B, 2)),
					  box_b[:, 0:2].unsqueeze(0).expand((A, B, 2)))
	# inter.shape: [A, B, (w, h)]
	# 
	inter = torch.clamp((br_xy - tl_xy), min=0)
	# return shape(A, B)
	# w * h 交集面积
	return inter[:, :, 0] * inter[:, :, 1]
  • 1
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值