前面的博客讲述了训练流程，从数据流的获取到 Loss 的选择；但在检测（推理）阶段，需要先进行 NMS 再输出结果。本篇博客主要讲述 NMS。
1、由于网络预测的是边框（Box）相对于 anchor 的偏移量（offsets），因此首先要把 offsets 作用到 anchor 上完成坐标转换，即进行 bbox decoder。下面是 NormalizedBoxCenterDecoder 的 __init__ 与 hybrid_forward 方法：
def __init__(self, stds=(0.1, 0.1, 0.2, 0.2), means=(0., 0., 0., 0.),
             convert_anchor=False, clip=None):
    """Configure how normalized center offsets are decoded back into boxes.

    Parameters
    ----------
    stds : sequence of 4 floats
        Per-component scale applied to the raw predicted offsets.
    means : sequence of 4 floats
        Per-component shift added after scaling.
    convert_anchor : bool
        If True, anchors are passed through ``BBoxCornerToCenter(split=True)``
        before decoding. Otherwise they are split as-is and must already be
        in the center layout the decoder consumes.
    clip : float or None
        When truthy, upper bound applied to the exponentiated width/height
        scale factors during decoding.

    Raises
    ------
    AssertionError
        If ``stds`` or ``means`` does not have exactly 4 entries.
    """
    super(NormalizedBoxCenterDecoder, self).__init__()
    # hybrid_forward indexes both stds[0..3] and means[0..3]; validate both
    # here instead of failing (or silently ignoring extras) at decode time.
    assert len(stds) == 4, "Box Decoder requires 4 std values."
    assert len(means) == 4, "Box Decoder requires 4 mean values."
    self._stds = stds
    self._means = means
    self._clip = clip
    if convert_anchor:
        # Child block that converts corner-format anchors to split center parts.
        self.corner_to_center = BBoxCornerToCenter(split=True)
    else:
        self.corner_to_center = None
def hybrid_forward(self, F, x, anchors):
    """Turn normalized center offsets ``x`` into corner boxes around ``anchors``.

    Parameters
    ----------
    F : module
        Operator namespace supplied by Gluon (ndarray when imperative,
        symbol when hybridized).
    x : NDArray or Symbol
        Predicted offsets; the last axis holds four components
        ``(dx, dy, dw, dh)``.
    anchors : NDArray or Symbol
        Reference boxes with four values on the last axis. After the optional
        corner-to-center conversion they are consumed as
        ``(center_x, center_y, width, height)``.

    Returns
    -------
    NDArray or Symbol
        Boxes ``(x1, y1, x2, y2)`` concatenated on the last axis, where
        ``center = (d * std + mean) * anchor_size + anchor_center`` and
        ``size = exp(d * std + mean) * anchor_size`` (scale optionally capped
        by ``clip``).
    """
    if self.corner_to_center is None:
        anchor_parts = F.split(anchors, axis=-1, num_outputs=4)
    else:
        anchor_parts = self.corner_to_center(anchors)
    acx, acy, aw, ah = (anchor_parts[k] for k in range(4))
    deltas = F.split(x, axis=-1, num_outputs=4)
    dx, dy, dw, dh = (deltas[k] for k in range(4))
    stds, means = self._stds, self._means

    # Center shift is scaled by the anchor size, then added to the anchor center.
    cx = F.broadcast_add(F.broadcast_mul(dx * stds[0] + means[0], aw), acx)
    cy = F.broadcast_add(F.broadcast_mul(dy * stds[1] + means[1], ah), acy)

    # Width/height are predicted in log space; exp() recovers a multiplicative scale.
    scale_w = F.exp(dw * stds[2] + means[2])
    scale_h = F.exp(dh * stds[3] + means[3])
    if self._clip:
        # Optional cap on the scale factor.
        scale_w = F.minimum(scale_w, self._clip)
        scale_h = F.minimum(scale_h, self._clip)

    half_w = F.broadcast_mul(scale_w, aw) / 2
    half_h = F.broadcast_mul(scale_h, ah) / 2
    return F.concat(cx - half_w, cy - half_h, cx + half_w, cy + half_h, dim=-1)
2、完成 bbox decoder 之后，需要进行类别解码（class decoder）：
class MultiPerClassDecoder(gluon.HybridBlock):
    """Expand class scores into parallel (class id, score) tensors per foreground class.

    Index 0 along the class axis is dropped (treated as background); the
    remaining ``num_class - 1`` scores are kept. Every position whose score
    is not strictly greater than ``thresh`` gets class id ``-1`` and score
    ``0``.

    Parameters
    ----------
    num_class : int
        Number of classes on the class axis of the input, including the
        dropped index 0. NOTE(review): presumably must be >= 2, since with no
        foreground classes ``F.concat`` receives no inputs — confirm.
    axis : int, default -1
        Axis of the input that indexes classes.
    thresh : float, default 0.01
        Minimum (exclusive) score for an entry to be kept.
    """

    def __init__(self, num_class, axis=-1, thresh=0.01):
        super(MultiPerClassDecoder, self).__init__()
        # Number of classes left after dropping index 0.
        self._fg_class = num_class - 1
        self._axis = axis
        self._thresh = thresh

    def hybrid_forward(self, F, x):
        """Return ``(cls_id, scores)``, both shaped like ``x`` minus one class entry.

        With the default ``axis=-1`` and an input presumably laid out as
        ``b x N x num_class`` (TODO confirm), both outputs are
        ``b x N x fg_class``.
        """
        # Drop index 0 along the class axis.
        scores = x.slice_axis(axis=self._axis, begin=1, end=None)
        # Zeros with size 1 on the class axis; adding i yields a constant id-i slab.
        # Slicing and concatenating on self._axis (rather than a hard-coded -1)
        # keeps cls_id the same shape as scores for any class axis.
        template = F.zeros_like(x.slice_axis(axis=self._axis, begin=0, end=1))
        cls_id = F.concat(*[template + i for i in range(self._fg_class)], dim=self._axis)
        mask = scores > self._thresh
        # Entries at or below the threshold: id -1, score 0.
        cls_id = F.where(mask, cls_id, F.ones_like(cls_id) * -1)
        scores = F.where(mask, scores, F.zeros_like(scores))
        return cls_id, scores
3、进行 NMS（非极大值抑制）：
# Excerpt from a detector's forward pass; the enclosing method is not shown.
# `results` is built earlier. Judging from id_index=0, score_index=1,
# coord_start=2 and the slices below, each detection's last axis is laid out as
# [class_id, score, coord0..coord3] — TODO confirm against the producer.
result = F.concat(*results, dim=1)
# Only run NMS when the IoU threshold lies strictly between 0 and 1.
if self.nms_thresh > 0 and self.nms_thresh < 1:
    # Per the MXNet box_nms docs: entries with score below valid_thresh are
    # dropped, at most `topk` candidates are considered, and with
    # force_suppress=False only boxes sharing the same id suppress each other.
    result = F.contrib.box_nms(
        result, overlap_thresh=self.nms_thresh, topk=self.nms_topk, valid_thresh=0.01,
        id_index=0, score_index=1, coord_start=2, force_suppress=False)
    # NOTE(review): indentation was lost in this excerpt; nesting of this
    # post-NMS truncation inside the NMS branch follows the GluonCV SSD
    # source — confirm.
    if self.post_nms > 0:
        # Keep only the first `post_nms` entries along axis 1.
        result = result.slice_axis(axis=1, begin=0, end=self.post_nms)
# Split the last axis back into class ids, scores and the 4 box coordinates.
ids = F.slice_axis(result, axis=2, begin=0, end=1)
scores = F.slice_axis(result, axis=2, begin=1, end=2)
bboxes = F.slice_axis(result, axis=2, begin=2, end=6)