CenterNet目标检测学习记录

模型使用

按照centernet源码官方教程执行即可,我使用的是torch 1.4.0版本,dcn好像要替换(忘了…),替换部分参考链接,我跑通demo后,去看源码,发现如果用Hourglass backbone的话,不编译dcn也能正常使用的(之前在windows怎么编译都不行),自己把代码提取出来,定义自己的模型是hourglass backbone就行了(默认是dla34).

训练自己的数据

训练自己的参考这篇博客链接

模型详解

1.hourglass网络

参考这两篇博客,介绍得比较详细链接1链接2,houglass之间的连接部分如下.(每个hourglass都有输出,下图第二个hourglass边幅原因,没画出来,计算损失时要计算每个输出,预测时只取最后一个输出)
在这里插入图片描述hm [batch,ncls,H,W].wh和reg [batch,2,H,W],所以,如果两个物体中心刚好重合,实际只能得到一个预测框,不过作者也说了,coco数据集上这种情况不到千分一,就没处理.

2.heatmap

centernet里面有很多设置,我基本都使用默认设置

1)半径的确定

半径的大小确定来源cornernet,主要是cornernet靠近gt角点的目标框与标签还是有很高的IOU,因此一定范围内的损失权重跟远的负样本不一致.这个范围就是通过IOU来确定半径得到的邻域,半径的确定参考
链接(此处链接结果貌似不太对,看原理即可),严格来说不是用r来计算,应该是rcos/rsin,不过影响不大
但是官方源码那里是有问题的,参考链接,链接
修正如下

def gaussian_radius(det_size, min_overlap=0.7):
  height, width = det_size

  a1  = 1
  b1  = (height + width)
  c1  = width * height * (1 - min_overlap) / (1 + min_overlap)
  sq1 = np.sqrt(b1 ** 2 - 4 * a1 * c1)

  #r1  = (b1 + sq1) / 2 #源代码是错的
  r1 = (b1 - sq1) / (2 * a1)

  a2  = 4
  b2  = 2 * (height + width)
  c2  = (1 - min_overlap) * width * height
  sq2 = np.sqrt(b2 ** 2 - 4 * a2 * c2)
  #r2  = (b2 + sq2) / 2
  r2 = (b2 - sq2) / (2 * a2)

  a3  = 4 * min_overlap
  b3  = -2 * min_overlap * (height + width)
  c3  = (min_overlap - 1) * width * height
  sq3 = np.sqrt(b3 ** 2 - 4 * a3 * c3)
  #r3  = (b3 + sq3) / 2
  r3 = (b3 + sq3) / (2 * a3)
  return min(r1, r2, r3)

2)heatmap绘制

在这里插入图片描述
x,y值是相对圆心位置,centernet内 一种 sigma取值如下
在这里插入图片描述

3.损失

1)论文上的heatmap损失

在这里插入图片描述作者采用alpha=2,beta=4

def _neg_loss(pred, gt):
  ''' Modified focal loss. Exactly the same as CornerNet.
      Runs faster and costs a little bit more memory
    Arguments:
      pred (batch x c x h x w)
      gt_regr (batch x c x h x w)
  '''
  pos_inds = gt.eq(1).float()#equal比较函数,正例
  neg_inds = gt.lt(1).float()#反例

  neg_weights = torch.pow(1 - gt, 4)# beta=4 ,alpha=2

  loss = 0

  pos_loss = torch.log(pred) * torch.pow(1 - pred, 2) * pos_inds
  neg_loss = torch.log(1 - pred) * torch.pow(pred, 2) * neg_weights * neg_inds

  num_pos  = pos_inds.float().sum()
  pos_loss = pos_loss.sum()
  neg_loss = neg_loss.sum()

  if num_pos == 0:
    loss = loss - neg_loss
  else:
    loss = loss - (pos_loss + neg_loss) / num_pos
  return loss

  1. wh,和偏移量 reg 默认是L1loss
    在这里插入图片描述
    在这里插入图片描述
class RegL1Loss(nn.Module):
  def __init__(self):
    super(RegL1Loss, self).__init__()
  
  def forward(self, output, mask, ind, target):
    pred = _transpose_and_gather_feat(output, ind) #[batch,maxobjs,2]
    mask = mask.unsqueeze(2).expand_as(pred).float()
    # loss = F.l1_loss(pred * mask, target * mask, reduction='elementwise_mean')
    loss = F.l1_loss(pred * mask, target * mask, size_average=False)
    loss = loss / (mask.sum() + 1e-4)
    return loss

总损失
在这里插入图片描述

  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 4
    评论
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值