#关于Cornernet当中解码的代码分析#
def _decode(
    tl_heat, br_heat, tl_tag, br_tag, tl_regr, br_regr,
    K=100, kernel=1, ae_threshold=1, num_dets=1000
):
    """Decode CornerNet head outputs into scored, classed bounding boxes.

    The top-K top-left corners and the top-K bottom-right corners are paired
    exhaustively, giving K * K candidate boxes. A pair is rejected (its score
    is set to -1) when the two corners predict different classes, when their
    embeddings differ by more than ``ae_threshold``, or when the box would be
    inverted in x or in y. The ``num_dets`` best-scoring candidates are
    returned.

    Args:
        tl_heat, br_heat: corner heatmap logits of shape
            (batch, num_classes, height, width); a sigmoid is applied here.
        tl_tag, br_tag: associative-embedding maps. They are gathered at the
            top-K corner indices, giving one scalar embedding per corner.
        tl_regr, br_regr: (x, y) offset maps used to refine corner positions,
            or ``None`` to skip refinement. Offsets are applied only when both
            are non-None.
        K: number of corners of each type kept from the heatmaps.
        kernel: window size forwarded to ``_nms``.
        ae_threshold: largest allowed |tl_tag - br_tag| for a valid pair.
        num_dets: number of detections returned per image. It is clamped to
            the number of candidate pairs (K * K), because ``torch.topk``
            cannot select more elements than exist.

    Returns:
        Tensor of shape (batch, min(num_dets, K * K), 8). The last axis is
        [x1, y1, x2, y2, score, tl_score, br_score, class], where score is the
        mean of the two corner scores. Coordinates are in the frame produced
        by ``_topk`` (presumably feature-map pixels; confirm against
        ``_topk``). Rejected pairs that still rank in the top ``num_dets``
        carry score -1.
    """
    # Unpacking (rather than size(0)) keeps the implicit 4-D input check.
    batch, _, _, _ = tl_heat.size()

    # Convert logits to probabilities in (0, 1).
    tl_heat = torch.sigmoid(tl_heat)
    br_heat = torch.sigmoid(br_heat)

    # Suppress non-peak responses on the heatmaps. Presumably this is a
    # max-pool that keeps only values equal to their local maximum; see _nms.
    tl_heat = _nms(tl_heat, kernel=kernel)
    br_heat = _nms(br_heat, kernel=kernel)

    # Top-K corners of each type: score, flat spatial index, class, y, x.
    tl_scores, tl_inds, tl_clses, tl_ys, tl_xs = _topk(tl_heat, K=K)
    br_scores, br_inds, br_clses, br_ys, br_xs = _topk(br_heat, K=K)

    # Broadcast to a (batch, K, K) grid: axis 1 indexes top-left corners and
    # axis 2 indexes bottom-right corners, so [b, i, j] is the pair (i, j).
    tl_ys = tl_ys.view(batch, K, 1).expand(batch, K, K)
    tl_xs = tl_xs.view(batch, K, 1).expand(batch, K, K)
    br_ys = br_ys.view(batch, 1, K).expand(batch, K, K)
    br_xs = br_xs.view(batch, 1, K).expand(batch, K, K)

    # Refine corner positions with the predicted (x, y) offsets, gathered
    # only at the K selected corner locations.
    if tl_regr is not None and br_regr is not None:
        tl_regr = _tranpose_and_gather_feat(tl_regr, tl_inds)
        tl_regr = tl_regr.view(batch, K, 1, 2)
        br_regr = _tranpose_and_gather_feat(br_regr, br_inds)
        br_regr = br_regr.view(batch, 1, K, 2)

        tl_xs = tl_xs + tl_regr[..., 0]
        tl_ys = tl_ys + tl_regr[..., 1]
        br_xs = br_xs + br_regr[..., 0]
        br_ys = br_ys + br_regr[..., 1]

    # Every top-K corner pairing as a box [x1, y1, x2, y2] (class ignored).
    bboxes = torch.stack((tl_xs, tl_ys, br_xs, br_ys), dim=3)

    # Embedding of each selected corner; the broadcast difference is the
    # pairwise embedding distance with shape (batch, K, K).
    tl_tag = _tranpose_and_gather_feat(tl_tag, tl_inds)
    tl_tag = tl_tag.view(batch, K, 1)
    br_tag = _tranpose_and_gather_feat(br_tag, br_inds)
    br_tag = br_tag.view(batch, 1, K)
    dists = torch.abs(tl_tag - br_tag)

    # Pair score is the mean of the two corner scores.
    tl_scores = tl_scores.view(batch, K, 1).expand(batch, K, K)
    br_scores = br_scores.view(batch, 1, K).expand(batch, K, K)
    scores = (tl_scores + br_scores) / 2

    tl_clses = tl_clses.view(batch, K, 1).expand(batch, K, K)
    br_clses = br_clses.view(batch, 1, K).expand(batch, K, K)

    # Reject every pair that violates any constraint. `scores` is a fresh
    # tensor (the result of the arithmetic above), so writing in place is safe.
    invalid = (
        (tl_clses != br_clses)        # corners disagree on class
        | (dists > ae_threshold)      # embeddings too far apart
        | (br_xs < tl_xs)             # bottom-right left of top-left
        | (br_ys < tl_ys)             # bottom-right above top-left
    )
    scores[invalid] = -1

    # Rank all K * K candidates per image. Clamp k because torch.topk raises
    # when asked for more elements than the dimension holds.
    scores = scores.view(batch, -1)
    num_dets = min(num_dets, scores.size(1))
    scores, inds = torch.topk(scores, num_dets)
    scores = scores.unsqueeze(2)

    # Gather the per-pair attributes of the selected candidates.
    bboxes = bboxes.view(batch, -1, 4)
    bboxes = _gather_feat(bboxes, inds)

    clses = tl_clses.contiguous().view(batch, -1, 1)
    clses = _gather_feat(clses, inds).float()

    tl_scores = tl_scores.contiguous().view(batch, -1, 1)
    tl_scores = _gather_feat(tl_scores, inds).float()
    br_scores = br_scores.contiguous().view(batch, -1, 1)
    br_scores = _gather_feat(br_scores, inds).float()

    # 8 values per detection: x1, y1, x2, y2, score, tl_score, br_score, class.
    detections = torch.cat([bboxes, scores, tl_scores, br_scores, clses], dim=2)
    return detections
# CornerNet 中关于解码函数 _decode() 的分析
# 最新推荐文章于 2021-02-13 12:23:34 发布
# 本文深入解析了CornerNet目标检测算法中解码模块的工作原理,详细介绍了如何从预测的热图、偏移量和嵌入向量中解码出最终的边界框,包括置信度评分、类别确定及坐标调整等关键步骤。
# 2994
#
# 被折叠的 条评论
# 为什么被折叠?



