CenterNet: Objects as Points

CenterNet paper link

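I. Background

1. Drawbacks of anchor-based detectors

(1). Anchor settings strongly affect the results. These hyperparameters have to be chosen empirically for each project, which is difficult.

(2). Anchors are too dense and most of them are negative samples, which introduces class imbalance.

(3). Working with anchors requires IoU computation, which adds computational cost.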

2. Applications

(1). Object detection

(2). 3D localization

(3). Human pose estimation

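II. Network Overview

The output consists of three branches:

(1) heatmap, of size (W/4, H/4, C): predicts the center points of objects, one channel per class.

(2) offset, of size (W/4, H/4, 2): predicts the offset of each center point.

(3) height & width, of size (W/4, H/4, 2): predicts the width and height of the box at each center point.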
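The three branches listed above can be sketched as a small PyTorch module. This is only an illustration under assumptions: the class name `CenterNetHead`, the 3x3 + 1x1 convolution layout, and the channel widths (`in_channels=32`, `mid_channels=64`) are hypothetical choices, not values taken from this article. Note that PyTorch stores tensors as (N, C, H, W), so the shapes appear in a different order than in the list above.

```python
import torch
import torch.nn as nn


class CenterNetHead(nn.Module):
    """Three parallel branches on top of a stride-4 feature map (illustrative)."""

    def __init__(self, in_channels: int, num_classes: int, mid_channels: int = 64):
        super().__init__()

        def branch(out_channels: int) -> nn.Sequential:
            # Assumed layout: 3x3 conv, ReLU, then a 1x1 conv to the output size.
            return nn.Sequential(
                nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1),
                nn.ReLU(inplace=True),
                nn.Conv2d(mid_channels, out_channels, kernel_size=1),
            )

        self.heatmap = branch(num_classes)  # one channel per class
        self.offset = branch(2)             # sub-pixel center offset (dx, dy)
        self.size = branch(2)               # box width and height

    def forward(self, feat):
        return {
            "heatmap": torch.sigmoid(self.heatmap(feat)),  # scores in (0, 1)
            "offset": self.offset(feat),
            "size": self.size(feat),
        }


# Assumed backbone output for a 512x512 input: 512 / 4 = 128 cells per side.
feat = torch.zeros(1, 32, 128, 128)
out = CenterNetHead(in_channels=32, num_classes=3)(feat)
shapes = {name: tuple(t.shape) for name, t in out.items()}
print(shapes)
```

All three outputs share the same 128x128 grid, so a peak found in the heatmap can index the other two branches directly.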

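1. Idea

The network predicts a heatmap of the objects; the peaks of the heatmap are the object centers.

For how the Gaussian kernel radius of the heatmap is derived, see this article and this article.

Code: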


import cv2
import numpy as np

np.set_printoptions(suppress=True)  # print floats without scientific notation


def gaussian_radius(det_size, min_overlap=0.7):
    """Radius (in grid cells) within which a shifted box keeps IoU >= min_overlap.

    Three cases are solved as quadratics and the smallest radius is returned.
    Each root is divided by 2 rather than by 2 * a; the original comments
    flag this difference from the standard quadratic formula.
    """
    box_w, box_h = det_size

    # Case 1: the box is shifted, one corner inside and one outside.
    a1 = 1
    b1 = (box_w + box_h)
    c1 = box_w * box_h * (1 - min_overlap) / (1 + min_overlap)
    sq1 = np.sqrt(b1 ** 2 - 4 * a1 * c1)
    r1 = (b1 + sq1) / 2  # standard formula would use (2 * a1)

    # Case 2: the predicted box lies inside the ground truth.
    a2 = 4
    b2 = 2 * (box_w + box_h)
    c2 = (1 - min_overlap) * box_w * box_h
    sq2 = np.sqrt(b2 ** 2 - 4 * a2 * c2)
    r2 = (b2 + sq2) / 2  # standard formula would use (2 * a2)

    # Case 3: the predicted box encloses the ground truth.
    a3 = 4 * min_overlap
    b3 = -2 * min_overlap * (box_w + box_h)
    c3 = (min_overlap - 1) * box_w * box_h
    sq3 = np.sqrt(b3 ** 2 - 4 * a3 * c3)
    r3 = (b3 + sq3) / 2  # standard formula would use (2 * a3)

    print('==r1, r2, r3:', r1, r2, r3)
    return min(r1, r2, r3)


# Ground-truth heatmap on the output grid: 512 / 4 = 128 cells, 3 class channels.
ws = hs = 512 // 4
gt_numpy = np.zeros((hs, ws, 3), dtype=np.float32)

# One 100 x 80 box, scaled to the output grid, centered at cell (64, 64).
box_w_s, box_h_s = 100 / 4, 80 / 4
r = gaussian_radius([box_w_s, box_h_s])
sigma_w = sigma_h = r / 3
print('===sigma_w:', sigma_w)

grid_x, grid_y = 64, 64
gt_cls = 0
half_w, half_h = 3 * int(sigma_w), 3 * int(sigma_h)

# Create the Gaussian heatmap around the center.
gt_numpy[grid_y, grid_x, gt_cls] = 1
for i in range(grid_x - half_w, grid_x + half_w + 1):
    for j in range(grid_y - half_h, grid_y + half_h + 1):
        if 0 <= i < ws and 0 <= j < hs:  # skip cells outside the heatmap
            v = np.exp(-(i - grid_x) ** 2 / (2 * sigma_w ** 2)
                       - (j - grid_y) ** 2 / (2 * sigma_h ** 2))
            # Keep the larger value where Gaussians of the same class overlap.
            gt_numpy[j, i, gt_cls] = max(v, gt_numpy[j, i, gt_cls])
print('===gt_numpy.shape:', gt_numpy.shape)

# Print the patch around the center.
middle_gt = gt_numpy[grid_y - half_h:grid_y + half_h + 1,
                     grid_x - half_w:grid_x + half_w + 1, gt_cls]
print(np.around(middle_gt, 2))

# Save the heatmap as a grayscale image and as a JET color map.
out_img = (gt_numpy[..., gt_cls] * 255).astype(np.uint8)
cv2.imwrite('./out_img.jpg', out_img)
warped_color = cv2.applyColorMap(out_img, cv2.COLORMAP_JET)
cv2.imwrite('./out_img_color.jpg', warped_color)
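The double loop above is easy to follow but slow when an image holds many objects. As a possible alternative, here is a minimal vectorized sketch that writes the same kind of Gaussian peak with NumPy slicing. The helper name `draw_gaussian` and its signature are assumptions made for this illustration; it uses a single `sigma` for both axes and clips the patch at the heatmap border.

```python
import numpy as np


def draw_gaussian(heatmap, center, sigma):
    """Splat a Gaussian peak onto a 2-D heatmap in place, keeping the per-cell max."""
    h, w = heatmap.shape
    cx, cy = center
    radius = int(3 * sigma)
    # Clip the patch to the heatmap bounds.
    x0, x1 = max(cx - radius, 0), min(cx + radius + 1, w)
    y0, y1 = max(cy - radius, 0), min(cy + radius + 1, h)
    ys, xs = np.mgrid[y0:y1, x0:x1]
    patch = np.exp(-((xs - cx) ** 2 + (ys - cy) ** 2) / (2 * sigma ** 2))
    # Element-wise max so that overlapping objects do not erase each other.
    heatmap[y0:y1, x0:x1] = np.maximum(heatmap[y0:y1, x0:x1], patch)
    return heatmap


hm = np.zeros((128, 128), dtype=np.float32)
draw_gaussian(hm, center=(64, 64), sigma=2.0)
draw_gaussian(hm, center=(2, 1), sigma=2.0)  # near the border: the patch is clipped
```

Taking the element-wise maximum mirrors the `max(v, ...)` step in the loop version, so the center of every object keeps the value 1.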

   

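2. Differences from anchor-based methods

(1). No threshold is needed to separate foreground from background.

(2). Each object corresponds to a single peak on the heatmap, and that peak is the object center, so NMS is not needed.

(3). The downsampling stride is small (only 4), so multiple overlapping boxes per location are not needed.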

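3. Heatmap and its focal loss (classification)

The heatmap is the heat map of the objects, and its number of channels equals the number of classes. The loss is a focal loss whose weights follow the Gaussian distribution of the ground truth: heatmap locations other than the exact center do not need to contribute the full loss.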
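The formula itself is not shown in the text here. As given in the CenterNet paper, and matching the PyTorch code further down term by term, the loss can be written as:

```latex
L_k = -\frac{1}{N} \sum_{xyc}
\begin{cases}
\left(1 - \hat{Y}_{xyc}\right)^{\alpha} \log\left(\hat{Y}_{xyc}\right)
  & \text{if } Y_{xyc} = 1 \\[4pt]
\left(1 - Y_{xyc}\right)^{\beta} \left(\hat{Y}_{xyc}\right)^{\alpha}
  \log\left(1 - \hat{Y}_{xyc}\right)
  & \text{otherwise}
\end{cases}
```

The factor (1 - Y)^β is what shrinks the penalty for locations near a center, where the Gaussian ground truth is close to 1.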

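Ŷxyc: the predicted heatmap value at (x, y) in channel c.

Yxyc: the ground-truth heatmap value at (x, y) in channel c, which follows a Gaussian distribution around each center.

α, β: hyperparameters that control the weighting of the loss.

N: the number of keypoints in the image.

PyTorch code example: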


import torch


def modified_focal_loss(pred, gt, alpha, beta):
    """
    Modified focal loss from CenterNet (numerically stable version).

    pred: predicted heatmap with values in [0, 1], shape (B, C, H, W)
    gt:   ground-truth Gaussian heatmap, same shape, exactly 1 at centers
    """
    pos_inds = gt.eq(1).float()  # object centers
    neg_inds = gt.lt(1).float()  # every other location

    # Negatives close to a center (gt near 1) are down-weighted.
    neg_weights = torch.pow(1 - gt, beta)
    # Clamp on both sides so neither log(pred) nor log(1 - pred) hits log(0).
    pred = torch.clamp(pred, 1e-12, 1 - 1e-6)

    pos_loss = torch.log(pred) * torch.pow(1 - pred, alpha) * pos_inds
    neg_loss = torch.log(1 - pred) * torch.pow(pred, alpha) * neg_weights * neg_inds

    num_pos = pos_inds.sum()
    pos_loss = pos_loss.sum()
    neg_loss = neg_loss.sum()
    if num_pos == 0:
        # No object in the batch: only the negative term remains.
        loss = -neg_loss
    else:
        # Normalize by the number of keypoints N.
        loss = -(pos_loss + neg_loss) / num_pos
    return loss


if __name__ == '__main__':
    b, c, h, w = 4, 10, 224, 224
    pred = torch.rand(b, c, h, w)
    # Random ground truth; adding 0.1 before clamping makes some values exactly 1.
    gt = torch.clamp(torch.rand(b, c, h, w) + 0.1, 0., 1.0)

    print('==pred.shape:', pred.shape)
    print('==gt.shape:', gt.shape)
    loss = modified_focal_loss(pred, gt, alpha=2, beta=4)
    print('=loss:', loss)

4. Offset loss (L1)

Offsets correct the box shift caused by downsampling, so that the boxes fit more tightly.
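The formula is not shown in the text here. As given in the CenterNet paper, the offset loss is an L1 loss evaluated only at keypoint locations, where p̃ = ⌊p/R⌋ is the keypoint on the output grid and Ô is the predicted offset map:

```latex
L_{off} = \frac{1}{N} \sum_{p}
\left| \hat{O}_{\tilde{p}} - \left( \frac{p}{R} - \tilde{p} \right) \right|,
\qquad \tilde{p} = \left\lfloor \frac{p}{R} \right\rfloor
```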

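p is a keypoint and R is the downsampling factor. Mapping a point on the predicted heatmap back to the input image loses precision, which hurts small objects badly, so a separate network branch learns this error.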
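A small worked example may make the precision loss concrete. The center coordinates below are hypothetical values chosen for illustration; the arithmetic simply follows the definitions of p and R.

```python
import numpy as np

R = 4                         # output stride (downsampling factor)
p = np.array([203.0, 118.0])  # hypothetical object center (x, y) in input pixels

p_low = p / R                          # exact center on the output grid
p_tilde = np.floor(p_low).astype(int)  # integer cell that holds the heatmap peak
offset_target = p_low - p_tilde        # what the offset branch should regress

# Mapping the integer cell back without the offset loses the fractional part.
recovered_plain = p_tilde * R
# Adding the offset first recovers the original center.
recovered_with_offset = (p_tilde + offset_target) * R

print(offset_target, recovered_plain, recovered_with_offset)
```

Here the center lands in cell (50, 29) with offset (0.75, 0.5). Without the offset the recovered center is (200, 116), an error of 3 and 2 pixels, which matters a lot for a box that is only a few pixels wide.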

5. Regression loss (L1)

The width and height are regressed with an L1 loss.
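The formula is not shown in the text here. As given in the CenterNet paper, with s_k the ground-truth size of object k and Ŝ the predicted size map read at the object's center p_k:

```latex
L_{size} = \frac{1}{N} \sum_{k=1}^{N} \left| \hat{S}_{p_k} - s_k \right|,
\qquad s_k = \left( x_2^{(k)} - x_1^{(k)},\; y_2^{(k)} - y_1^{(k)} \right)
```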

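6. Total loss

The loss has three parts: the heatmap classification loss, the width/height regression loss, and the offset regression loss.

The output has (number of classes + 4) channels: the 4 extra channels are the width, the height, and the two center offsets.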
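The combination of the three terms can be sketched as follows. The weights `lambda_size=0.1` and `lambda_off=1.0` are the values reported in the CenterNet paper; the helper names `masked_l1` and `total_loss`, and the way the L1 terms are normalized, are assumptions made for this illustration and may differ from the official code.

```python
import torch
import torch.nn.functional as F


def masked_l1(pred, target, mask):
    """L1 loss over object-center cells only, divided by the number of objects.

    pred, target: (B, 2, H, W); mask: (B, 1, H, W) with 1 at object centers.
    """
    num_pos = mask.sum().clamp(min=1)  # avoid division by zero on empty images
    return F.l1_loss(pred * mask, target * mask, reduction="sum") / num_pos


def total_loss(hm_loss, size_loss, off_loss, lambda_size=0.1, lambda_off=1.0):
    """Weighted sum of the heatmap, size, and offset losses."""
    return hm_loss + lambda_size * size_loss + lambda_off * off_loss


# Toy example: one object whose center is at cell (y=2, x=3) on a 4x5 grid.
mask = torch.zeros(1, 1, 4, 5)
mask[0, 0, 2, 3] = 1
size_gt = torch.zeros(1, 2, 4, 5)
size_gt[0, :, 2, 3] = torch.tensor([25.0, 20.0])
size_pred = torch.full((1, 2, 4, 5), 10.0)  # predictions off-center are ignored

size_loss = masked_l1(size_pred, size_gt, mask)  # |10 - 25| + |10 - 20| = 25
loss = total_loss(torch.tensor(1.0), size_loss, torch.tensor(0.5))
print(size_loss.item(), loss.item())
```

The mask is what restricts the two regression losses to keypoint locations; every other cell of the size and offset maps receives no gradient.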

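7. Inference

On the heatmap, a point is a peak if its value is no smaller than its 8 neighbors, which is checked with a 3*3 max pooling. The top 100 peaks are kept as center points, and combining each one with the predicted width/height and offset gives the detection box.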

(x̂, ŷ): the predicted center point

(δx̂, δŷ): the predicted center offset

(ŵ, ĥ): the predicted width and height
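The decoding step can be sketched in PyTorch as below. This is a minimal single-image illustration, not the official implementation: the function name `decode`, its argument layout, and the output row format are assumptions. Each box is formed as (x̂ + δx̂ - ŵ/2, ŷ + δŷ - ĥ/2, x̂ + δx̂ + ŵ/2, ŷ + δŷ + ĥ/2) on the output grid.

```python
import torch
import torch.nn.functional as F


def decode(heatmap, offset, size, k=100):
    """heatmap: (C, H, W) in [0, 1]; offset, size: (2, H, W).

    Returns (k, 6) rows of [x1, y1, x2, y2, score, class] on the output grid.
    """
    c, h, w = heatmap.shape
    # A point is a peak if it equals the maximum of its 3x3 neighborhood.
    pooled = F.max_pool2d(heatmap[None], kernel_size=3, stride=1, padding=1)[0]
    peaks = heatmap * (pooled == heatmap)

    # Keep the k strongest peaks over all classes and locations.
    k = min(k, peaks.numel())
    scores, idx = torch.topk(peaks.flatten(), k)
    cls = torch.div(idx, h * w, rounding_mode="floor")
    ys = torch.div(idx % (h * w), w, rounding_mode="floor")
    xs = idx % w

    # Refine the integer center with the offset, then attach width and height.
    cx = xs.float() + offset[0, ys, xs]
    cy = ys.float() + offset[1, ys, xs]
    bw, bh = size[0, ys, xs], size[1, ys, xs]
    return torch.stack(
        [cx - bw / 2, cy - bh / 2, cx + bw / 2, cy + bh / 2, scores, cls.float()],
        dim=1,
    )


# Toy example: one class-1 peak at (x=5, y=3) and a weaker neighbor next to it.
hm = torch.zeros(2, 8, 8)
hm[1, 3, 5] = 0.9
hm[1, 3, 4] = 0.5  # not a local maximum, so max pooling suppresses it
off = torch.zeros(2, 8, 8)
off[:, 3, 5] = torch.tensor([0.25, 0.5])
sz = torch.zeros(2, 8, 8)
sz[:, 3, 5] = torch.tensor([4.0, 2.0])

det = decode(hm, off, sz, k=1)
print(det)
```

Multiplying the four box coordinates by the stride (4 here) maps them back to input-image pixels. The weaker neighbor is dropped by the peak test, which is the role NMS plays in anchor-based detectors.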

 

III. Experimental Results
