前言
NSM非极大值抑制的代码实现
一、NMS是什么?
对于多分类的目标检测,依次删除每一类中得分概率较低以及得分概率较高但是可被替代的Boxes(有更高得分概率的框与该框重合度较高)
二、代码实现
def decode_single(self, bboxes_in, scores_in, criteria, max_output, max_num=200):
"""
decode:
input : bboxes_in (Tensor 8732 x 4), scores_in (Tensor 8732 x nitems)
output : bboxes_out (Tensor nboxes x 4), labels_out (Tensor nboxes)
criteria : IoU threshold of bboexes
max_output : maximum number of output bboxes
"""
# Reference to https://github.com/amdegroot/ssd.pytorch
bboxes_out = []
scores_out = []
labels_out = []
# 非极大值抑制算法
# scores_in (Tensor 8732 x nitems), 遍历返回每一列数据,即8732个目标的同一类别的概率
for i, score in enumerate(scores_in.split(1, 1)):
# skip background
if i == 0:
continue
# [8732, 1] -> [8732]
score = score.squeeze(1)
# 虑除预测概率小于0.05的目标
mask = score > 0.05
bboxes, score = bboxes_in[mask, :], score[mask]
if score.size(0) == 0:
continue
# 按照分数从小到大排序
score_sorted, score_idx_sorted = score.sort(dim=0)
# select max_output indices
score_idx_sorted = score_idx_sorted[-max_num:]
candidates = []
while score_idx_sorted.numel() > 0:
idx = score_idx_sorted[-1].item()
# 获取排名前score_idx_sorted名的bboxes信息 Tensor:[score_idx_sorted, 4]
bboxes_sorted = bboxes[score_idx_sorted, :]
# 获取排名第一的bboxes信息 Tensor:[4]
bboxes_idx = bboxes[idx, :].unsqueeze(dim=0)
# 计算前score_idx_sorted名的bboxes与第一名的bboxes的iou
iou_sorted = calc_iou_tensor(bboxes_sorted, bboxes_idx).squeeze()
# we only need iou < criteria
# 丢弃与第一名iou > criteria的所有目标(包括自己本身)
score_idx_sorted = score_idx_sorted[iou_sorted < criteria]
# 保存第一名的索引信息
candidates.append(idx)
# 保存该类别通过非极大值抑制后的目标信息
bboxes_out.append(bboxes[candidates, :]) # bbox坐标信息
scores_out.append(score[candidates]) # score信息
labels_out.extend([i] * len(candidates)) # 标签信息
if not bboxes_out: # 如果为空的话,返回空tensor,注意boxes对应的空tensor size,防止验证时出错
return [torch.empty(size=(0, 4)), torch.empty(size=(0,), dtype=torch.int64), torch.empty(size=(0,))]
bboxes_out = torch.cat(bboxes_out, dim=0).contiguous()
scores_out = torch.cat(scores_out, dim=0).contiguous()
labels_out = torch.as_tensor(labels_out, dtype=torch.long)
# 对所有目标的概率进行排序(无论是什么类别),取前max_num个目标
_, max_ids = scores_out.sort(dim=0)
max_ids = max_ids[-max_output:]
return bboxes_out[max_ids, :], labels_out[max_ids], scores_out[max_ids]
总结
基础的NMS实现,面试可能会考察。了解相关具体实现,也能学到更多基础原理