CornerNet 中与 loss 和解码相关的函数实际上位于 kp.py 和 kp_utils.py 中。
解码函数如下所示：
def _decode(
tl_heat, br_heat, tl_tag, br_tag, tl_regr, br_regr,
K=100, kernel=1, ae_threshold=1, num_dets=1000
):
batch, cat, height, width = tl_heat.size()
tl_heat = torch.sigmoid(tl_heat)
br_heat = torch.sigmoid(br_heat)
# perform nms on heatmaps
"""
其实就是对概率图进行maxpooling
"""
tl_heat = _nms(tl_heat, kernel=kernel)
br_heat = _nms(br_heat, kernel=kernel)
tl_scores, tl_inds, tl_clses, tl_ys, tl_xs = _topk(tl_heat, K=K)
br_scores, br_inds, br_clses, br_ys, br_xs = _topk(br_heat, K=K)
"""
tl_ys,tl_xs原本的shape为[batch,K]
"""
tl_ys = tl_ys.view(batch, K, 1).expand(batch, K, K)
tl_xs = tl_xs.view(batch, K, 1).expand(batch, K, K)
br_ys = br_ys.view(batch, 1, K).expand(batch,