Anchor Free系列模型03

2021SC@SDUSC
loss计算代码

损失函数
heatmap loss
输入图像在这里插入图片描述,W为图像宽度,H为图像高度。网络输出的关键点热图heatmap为在这里插入图片描述其中,R代表得到输出相对于原图的步长stride。C代表类别个数。
下面是CenterNet中核心loss公式:
在这里插入图片描述
这个和Focal loss形式很相似,和是超参数,N代表的是图像关键点个数。

在这里插入图片描述的时候,
对于易分样本来说,预测值在这里插入图片描述接近于1,在这里插入图片描述就是一个很小的值,这样loss就很小,起到了矫正作用。

对于难分样本来说,预测值在这里插入图片描述接近于0,在这里插入图片描述就比较大,相当于加大了其训练的比重。

otherwise的情况下:
在这里插入图片描述
上图是一个简单的示意,纵坐标是在这里插入图片描述,分为A区(距离中心点较近,但是值在0-1之间)和B区(距离中心点很远接近于0)。
对于A区来说,由于其周围是一个高斯核生成的中心,的值是从1慢慢变到0。
举个例子(CenterNet中默认α=2,β=4):
在这里插入图片描述
总结一下:为了防止预测值在这里插入图片描述过高接近于1,所以用在这里插入图片描述来惩罚Loss。而在这里插入图片描述这个参数距离中心越近,其值越小,这个权重是用来减轻惩罚力度。

对于B区来说,的预测值在这里插入图片描述理应是0,如果该值比较大比如为1,那么在这里插入图片描述作为权重会变大,惩罚力度也加大了。如果预测值接近于0,那么在这里插入图片描述会很小,让其损失比重减小。对于在这里插入图片描述来说,B区的值比较大,弱化了中心点周围其他负样本的损失比重。

代码解析

# 得到heat map, reg, wh 三个变量
hmap, regs, w_h_ = zip(*outputs)

regs = [
_tranpose_and_gather_feature(r, batch['inds']) for r in regs
]
w_h_ = [
_tranpose_and_gather_feature(r, batch['inds']) for r in w_h_
]

# 分别计算loss
hmap_loss = _neg_loss(hmap, batch['hmap'])
reg_loss = _reg_loss(regs, batch['regs'], batch['ind_masks'])
w_h_loss = _reg_loss(w_h_, batch['w_h_'], batch['ind_masks'])

# 进行loss加权,得到最终loss
loss = hmap_loss + 1 * reg_loss + 0.1 * w_h_loss

上述transpose_and_gather_feature函数具体实现如下,主要功能是将ground truth中计算得到的对应中心点的值获取。

def _tranpose_and_gather_feature(feat, ind):
  # ind代表的是ground truth中设置的存在目标点的下角标
  feat = feat.permute(0, 2, 3, 1).contiguous()# from [bs c h w] to [bs, h, w, c] 
  feat = feat.view(feat.size(0), -1, feat.size(3)) # to [bs, wxh, c]
  feat = _gather_feature(feat, ind)
  return feat

def _gather_feature(feat, ind, mask=None):
  # feat : [bs, wxh, c]
  dim = feat.size(2)
  # ind : [bs, index, c]
  ind = ind.unsqueeze(2).expand(ind.size(0), ind.size(1), dim)
  feat = feat.gather(1, ind) # 按照dim=1获取ind
  if mask is not None:
    mask = mask.unsqueeze(2).expand_as(feat)
    feat = feat[mask]
    feat = feat.view(-1, dim)
  return feat
def _neg_loss(preds, targets):
    ''' Modified focal loss. Exactly the same as CornerNet.
        Runs faster and costs a little bit more memory
        Arguments:
        preds (B x c x h x w)
        gt_regr (B x c x h x w)
    '''
    pos_inds = targets.eq(1).float()# heatmap为1的部分是正样本
    neg_inds = targets.lt(1).float()# 其他部分为负样本

    neg_weights = torch.pow(1 - targets, 4)# 对应(1-Yxyc)^4

    loss = 0
    for pred in preds: # 预测值
        # 约束在0-1之间
        pred = torch.clamp(torch.sigmoid(pred), min=1e-4, max=1 - 1e-4)
        pos_loss = torch.log(pred) * torch.pow(1 - pred, 2) * pos_inds
        neg_loss = torch.log(1 - pred) * torch.pow(pred,
                                                   2) * neg_weights * neg_inds
        num_pos = pos_inds.float().sum()
        pos_loss = pos_loss.sum()
        neg_loss = neg_loss.sum()

        if num_pos == 0:
            loss = loss - neg_loss # 只有负样本
        else:
            loss = loss - (pos_loss + neg_loss) / num_pos
    return loss / len(preds)

在这里插入图片描述
代码和以上公式一一对应,pos代表正样本,neg代表负样本。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值