Analysis of the decoding function (`_decode()`) in CornerNet

This post walks through how the decoding module of the CornerNet object detector works: how the final bounding boxes are decoded from the predicted heatmaps, offsets and embedding vectors, covering the key steps of confidence scoring, class assignment and coordinate refinement.
# Code analysis of the decoding step in CornerNet #

import torch


def _decode(
    tl_heat, br_heat, tl_tag, br_tag, tl_regr, br_regr,
    K=100, kernel=1, ae_threshold=1, num_dets=1000
):
    # batch is the batch size, cat the number of object classes in the dataset,
    # and height/width are the spatial size of the prediction feature map
    batch, cat, height, width = tl_heat.size()
    # tl_heat has shape (batch, cat, H, W); sigmoid maps the raw corner scores into (0, 1)
    tl_heat = torch.sigmoid(tl_heat)
    br_heat = torch.sigmoid(br_heat)

    # perform nms on heatmaps
    # A max-pooling pass keeps a location only if it equals its local maximum and zeroes
    # out every other score (see the _nms sketch after the function)
    tl_heat = _nms(tl_heat, kernel=kernel)
    br_heat = _nms(br_heat, kernel=kernel)
	
    # Keep the top-K scores on each heatmap along with their indices, classes and coordinates.
    # tl_scores: top-left scores, tl_inds: integer indices into the flattened H*W grid,
    # tl_clses: top-left classes, tl_ys / tl_xs: top-left row / column coordinates
    # (see the _topk sketch after the function)
    tl_scores, tl_inds, tl_clses, tl_ys, tl_xs = _topk(tl_heat, K=K)
    br_scores, br_inds, br_clses, br_ys, br_xs = _topk(br_heat, K=K)
	
    # Broadcast the coordinates to (batch, K, K) so that every top-left corner can be
    # paired with every bottom-right corner
    tl_ys = tl_ys.view(batch, K, 1).expand(batch, K, K)
    tl_xs = tl_xs.view(batch, K, 1).expand(batch, K, K)
    br_ys = br_ys.view(batch, 1, K).expand(batch, K, K)
    br_xs = br_xs.view(batch, 1, K).expand(batch, K, K)
	
    # Offsets are predicted at every location; gather only the offsets that belong to the
    # top-K corners (see the _tranpose_and_gather_feat sketch after the function)
    if tl_regr is not None and br_regr is not None:
        tl_regr = _tranpose_and_gather_feat(tl_regr, tl_inds)
        tl_regr = tl_regr.view(batch, K, 1, 2)
        br_regr = _tranpose_and_gather_feat(br_regr, br_inds)
        br_regr = br_regr.view(batch, 1, K, 2)
		
        # The regressed offsets are in the x and y directions; add them to the integer
        # corner coordinates to refine the positions
        tl_xs = tl_xs + tl_regr[..., 0]
        tl_ys = tl_ys + tl_regr[..., 1]
        br_xs = br_xs + br_regr[..., 0]
        br_ys = br_ys + br_regr[..., 1]
		
    # all possible boxes based on top k corners (ignoring class)
    # each box is assembled as [x1, y1, x2, y2]
    bboxes = torch.stack((tl_xs, tl_ys, br_xs, br_ys), dim=3)
	
    # tl_tag / br_tag are the corner embedding vectors; keep only those of the top-K corners
    tl_tag = _tranpose_and_gather_feat(tl_tag, tl_inds)
    tl_tag = tl_tag.view(batch, K, 1)
    br_tag = _tranpose_and_gather_feat(br_tag, br_inds)
    br_tag = br_tag.view(batch, 1, K)
	
    # Pairwise distance between every top-left and bottom-right embedding; broadcasting
    # (batch, K, 1) against (batch, 1, K) yields a (batch, K, K) matrix
    # (a toy example of this broadcast follows the function)
    dists  = torch.abs(tl_tag - br_tag)
	
    # Broadcast the corner scores to (batch, K, K) for the pairwise corner combinations
    tl_scores = tl_scores.view(batch, K, 1).expand(batch, K, K)
    br_scores = br_scores.view(batch, 1, K).expand(batch, K, K)
	
    # The score of each corner pair is the mean of its top-left and bottom-right scores
    scores    = (tl_scores + br_scores) / 2

    # Broadcast the class labels in the same way
    tl_clses = tl_clses.view(batch, K, 1).expand(batch, K, K)
    br_clses = br_clses.view(batch, 1, K).expand(batch, K, K)
	
    # Mask of pairs whose top-left and bottom-right corners belong to different classes
    cls_inds = (tl_clses != br_clses)

    # Mask of pairs whose embedding distance exceeds ae_threshold
    # (the default argument is 1; CornerNet's test configs pass 0.5)
    dist_inds = (dists > ae_threshold)

    # Mask of geometrically invalid pairs: the bottom-right x lies left of the top-left x,
    # or the bottom-right y lies above the top-left y
    width_inds  = (br_xs < tl_xs)
    height_inds = (br_ys < tl_ys)
	
    # Invalidate every rejected pair by setting its score to -1
    scores[cls_inds]    = -1
    scores[dist_inds]   = -1
    scores[width_inds]  = -1
    scores[height_inds] = -1
	
    # Flatten the pair scores to (batch, K*K)
    scores = scores.view(batch, -1)
	
    # Keep the num_dets (1000 by default) highest-scoring pairs; inds indexes the flattened
    # K*K pair space and is used below to gather the matching boxes, classes, tl_scores
    # and br_scores
    scores, inds = torch.topk(scores, num_dets)
    scores = scores.unsqueeze(2)

    bboxes = bboxes.view(batch, -1, 4)
    bboxes = _gather_feat(bboxes, inds)

    clses  = tl_clses.contiguous().view(batch, -1, 1)
    clses  = _gather_feat(clses, inds).float()

    tl_scores = tl_scores.contiguous().view(batch, -1, 1)
    tl_scores = _gather_feat(tl_scores, inds).float()
    br_scores = br_scores.contiguous().view(batch, -1, 1)
    br_scores = _gather_feat(br_scores, inds).float()
	
    # Each detection row holds 8 numbers: [x1, y1, x2, y2, score, tl_score, br_score, cls]
    detections = torch.cat([bboxes, scores, tl_scores, br_scores, clses], dim=2)
    return detections
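
# Appendix: sketches of the helper functions #

The comments above refer to several helpers from the CornerNet repository. The sketches
below are simplified reconstructions of those helpers (modeled on the repo's kp_utils and
written against a recent PyTorch), not verbatim copies, so treat names and signatures as
approximate.

_nms implements the max-pooling "NMS" applied to the heatmaps: a score survives only if
it is the maximum of its local window.

import torch
import torch.nn as nn

def _nms(heat, kernel=1):
    # Keep a location only if it equals the maximum inside a kernel x kernel window;
    # every other location is suppressed to 0
    pad = (kernel - 1) // 2
    hmax = nn.functional.max_pool2d(heat, (kernel, kernel), stride=1, padding=pad)
    keep = (hmax == heat).float()
    return heat * keep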
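
_topk flattens each heatmap batch to (batch, cat*H*W), takes the K highest scores, and
recovers the class, the index inside a single H*W map, and the (y, x) coordinates from
the flat index. Integer division is written with rounding_mode='floor' for current PyTorch.

def _topk(scores, K=20):
    # scores: (batch, cat, H, W); top-K is taken over all classes and locations at once
    batch, cat, height, width = scores.size()
    topk_scores, topk_inds = torch.topk(scores.view(batch, -1), K)
    topk_clses = torch.div(topk_inds, height * width, rounding_mode='floor')
    topk_inds  = topk_inds % (height * width)      # index inside a single H*W heatmap
    topk_ys    = torch.div(topk_inds, width, rounding_mode='floor').float()
    topk_xs    = (topk_inds % width).float()
    return topk_scores, topk_inds, topk_clses, topk_ys, topk_xs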
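
_tranpose_and_gather_feat moves the channel dimension last and flattens the spatial grid,
after which _gather_feat picks out the rows addressed by the top-K indices, so only the
offsets and embeddings of the selected corners are kept.

def _gather_feat(feat, ind):
    # feat: (batch, H*W, dim), ind: (batch, K) -> (batch, K, dim)
    dim = feat.size(2)
    ind = ind.unsqueeze(2).expand(ind.size(0), ind.size(1), dim)
    return feat.gather(1, ind)

def _tranpose_and_gather_feat(feat, ind):
    # feat: (batch, dim, H, W) -> (batch, H*W, dim), then gather the selected rows
    feat = feat.permute(0, 2, 3, 1).contiguous()
    feat = feat.view(feat.size(0), -1, feat.size(3))
    return _gather_feat(feat, ind)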
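
The reshapes to (batch, K, 1) and (batch, 1, K) exist so that a single subtraction
broadcasts to a (batch, K, K) matrix in which entry (i, j) compares top-left corner i
with bottom-right corner j. A toy example with made-up embedding values:

tl = torch.tensor([0.10, 0.92, 0.48]).view(1, 3, 1)   # 3 top-left embeddings
br = torch.tensor([0.12, 0.90, 0.47]).view(1, 1, 3)   # 3 bottom-right embeddings
dists = torch.abs(tl - br)   # shape (1, 3, 3); dists[0, i, j] = |tl_i - br_j|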
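
Finally, a hypothetical smoke test of _decode, assuming the helpers above are defined in
the same module. The tensor shapes follow the usual CornerNet heads (80 COCO classes,
1-channel embeddings, 2-channel offsets); the random inputs are only meant to show the
expected output shape, not to produce meaningful detections.

batch, cat, H, W = 2, 80, 128, 128
tl_heat = torch.randn(batch, cat, H, W)
br_heat = torch.randn(batch, cat, H, W)
tl_tag  = torch.randn(batch, 1, H, W)
br_tag  = torch.randn(batch, 1, H, W)
tl_regr = torch.rand(batch, 2, H, W)
br_regr = torch.rand(batch, 2, H, W)

dets = _decode(tl_heat, br_heat, tl_tag, br_tag, tl_regr, br_regr,
               K=100, kernel=3, ae_threshold=0.5, num_dets=1000)
print(dets.shape)   # torch.Size([2, 1000, 8])
# columns: x1, y1, x2, y2, pair score, tl_score, br_score, class id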