浅入浅出谈CenterNet

本文介绍了CenterNet论文的核心思想,它是一种无Anchor的目标检测模型,通过预测物体中心点、检测框尺寸和坐标补偿来确定物体位置。相比于Anchor-Based模型,CenterNet简化了流程,减少了预设检测框的步骤,但也存在如物体重叠时无法区分的问题。损失函数采用了FocalLoss以应对正负样本不平衡。文章还探讨了GroundTruth的设计和目标检测的基本概念。
摘要由CSDN通过智能技术生成

CenterNet论文的名称为:Objects as Points[1],即将“物体”视为“点”。一语道破了这个模型的核心方法,即通过模型输出的特征图上某一点的信息,直接推断此位置某相关物体的信息(i.e. 检测框信息)。

原文给出了CenterNet很多应用,如3D目标检测、人体姿势检测等,由于方法大同小异,本文只给出2D目标检测的方式方法。

本文主要分为六个部分:什么是目标检测、什么是Anchor、Anchor-Based模型缺点、CenterNet是如何预测结果的、Ground Truth是如何设计的以及损失函数的设计。

一、 什么是目标检测

由于默认本文读者为有一定基础的计算机视觉爱好者,因此此节不做过多细节描述。目标检测即通过计算机视觉技术,对输入图像进行处理,得到我们感兴趣的物体的位置以及类别。位置一般会用图像坐标系中的坐标(i.e. 实数)表示,类别则一般由提前定义好的类别索引(i.e. 自然数)表示。因此,目标检测问题其实已经同时触及了机器学习中的两大类:回归与分类。回归实数坐标+分类离散类别,就是目标检测任务抽象后的核心要素。

二、什么是Anchor

Anchor是目标检测中一个很重要的概念,我们可以理解成提前预设好的检测框。基于是否提前设定Anchor,我们又可以将目标检测模型分为Anchor-Based模型和Anchor-Free模型。Anchor-Based模型最著名的就是RCNN[2],而最初的Anchor-Free模型则是大名鼎鼎的Yolo[3]。我们一般认为,Anchor-Free模型胜在速度,而精准度上则普遍不如Anchor-Based模型。胜在速度的原因很显然,因为我们需要用一些算法(e.g. selective search,RPN)去花时间筛选并预设这些Anchors,因此这类网络也叫做两阶段模型(Two-Stage)。相对地,不需要单独计算Anchor的称为一阶段模型。

注意,不是每个Anchor-Based模型都需要去单独使用算法预设Anchors,比如YoloV2[4],这个模型只需要预设若干给定长宽的Anchors,然后去回归补偿(offset)就好。

至于精准度上为什么Anchor-Based模型更优?这个原因看似很显然,但实际平不算平凡(trivial)。我们可以显然的理解为,有独立的、额外的算法/模型去计算检测框,当然会使检测框形状、位置预测的更加精准。但理论上只要能够学到足够的特征,就像yolo一样,依然可以直接回归出较为准确的检测框。这个问题我查了很多资料,没有很明确的解答。

关于这个问题,论文[4]中提到,对于卷积网络而言,预测补偿比直接预测检测框坐标对于模型来说,是更容易学到的,因此YoloV2才引入了Anchors。另外,个人理解,Anchor-Free模型的目标函数图像会比引入先验知识的Anchor-Based模型复杂很多,因此收敛也会遇到很大问题。由于Yolo发表较早,因此学习起来会比Anchor-Based模型困难很多,有可能会导致准确度有所降低。

更多关于这两类模型的具体介绍,可以参考论文[5]。

三、Anchor-Based模型的缺点

首先,如何定义这些Anchors的大小和长宽比?如上所说,一般是使用一些传统计算机视觉方法(Selective search,K聚类)或者神经网络(RPN)。但不管用哪种方法,都会使得模型的推理速度变慢。另外,Anchors的数量选取也是一个问题,如果选的少,那么我们可能很难覆盖所有可能性,就会造成准确度的下降;相反地,如果选的多,那么会造成两个问题:1. 推理速度变慢,因为每个Anchor都会送入卷积网络中进行分类;2. 会造成正样本和负样本的严重失衡。

简而言之,就是Anchors虽然可以让模型更简单、精准地学习到想要的输出,但同时也对模型的能力有了一定的限制,使模型很难泛化到更极端的情况下。这也是为什么越来越多的研究人员开始对Anchor-Free模型感兴趣。

四、CenterNet是如何预测检测框的?

回到主题中来,CenterNet是目前比较经典的一个Anchor-Free模型,其演变于之前发表的一个Anchor-Free模型——CornerNet[6],因此在CenterNet中依然可以看到CornerNet的影子。

CenterNet核心原理,图源[1]

首先,要想了解一个模型在做什么,最重要的就是了解这个模型的输出是什么。CenterNet输出是一个三维特征图,维度为(C+4, H,W),其中C为要预测的类别数,H为高,W为宽。你也可以理解成是一个H*W的二维特征图,而图上每个点都是一个C+4维向量。

我们先来说一下C+4分别代表什么,上面已经说了,C代表类别数,也就是说,对于每类物体,我们都有一个H*W的二维特征图,这个图我们一般称之为keypoints heatmap,顾名思义,他是一个热图,“热度值”高的点代表关键点,而在CenterNet中,关键点就是该物体的检测框的中心点。

也就是说:

  1. 对于每类物体,我们都有一个heatmap,取一个阈值,heatmap上高于这个阈值的点(采用局部最大值对比方法),我们视之为关键点,也就是说在这个点附近是有这类的一个物体的,且这个物体的检测框中心点就是这个关键点;

  1. 假设我们的认为是定位所有的猫和狗,那么我们的C=2。 若一张图片上,有2只小狗,3只小猫,那么共会生成2个heatmap,如果我们的模型训练足够准确,阈值取的足够好,那么我们的两个heatmap上应该一个有2个关键点(狗),一个有3个关键点(猫)且这五个关键点所在位置均为这五只猫狗的中心点附近;

现在已经有了物体的中心点,那么如何生成检测框呢?这就用到了C+4个特征图中的后四个特征图了。这四个特征图上的数值,分别代表检测框高度(h),检测框宽度(w),检测框中心点横坐标补偿(x_offset),检测框中心点纵坐标补偿(y_offset)。

h和w是非常好理解的,在这里不赘述。那么坐标补偿是做什么用的呢?由于我们在生成特征图的时候需要通过下采样来提高感受野,因此假设我们的输入图片大小为640*480,可能我们的输出特征图大小仅为160*120(i.e. 下采样因子为4)。也就是说,如果原图中,某物体的中心点位置为(123,123),那么在特征图中,他的横纵坐标就是123/4下取整为30。那么如果我们再还原回去,该点的坐标即为(30*4, 30*4)=(120, 120)。这样的话我们在横纵坐标上都有3像素的偏移。为了让预测结果更准,我们因此加入了坐标补偿的回归值,通过模型自学习,输出通过下采样->还原过程中产生的那部分偏移,以对结果进行更精准地校正。

细心的读者可能已经发现,检测框高度、宽度、横坐标补偿、纵坐标补偿这四个特征图,并不是为每一类物体而服务的,而是为全体物体服务的(因为不是4*C而是4+C)。这也是CenterNet一个潜在的缺点,就是如果有2个物体的中心点(不管是同类还是异类)完全重合,那么CenterNet将不会输出其中一个物体。当然,完全重合的概率是非常低的,因此CenterNet在不同数据集上的表现依旧很鲁棒。

总结一下:模型的输出可以分为三大检测头:1. 中心点;2. 检测框长宽;3. 中心点坐标补偿。共计C+4个二维特征图,在C个特征图(heatmap)上,我们通过局部最大值筛选法,筛出高于某个阈值的局部最大点,这些点就是对应类别上的物体中心点。在得到这些点的坐标后,我们在后四个特征图(i.e. 检测框高度、宽度、横坐标补偿、纵坐标补偿)上分别取这些点位置的4维向量,还原出这些点对应的检测框。

五、Ground Truth(gt)是如何设计的

gt的编码基本与模型输出基本一致,就是将中心点坐标、检测框长宽、检测框坐标补偿均生成二维图像。但是里面有一些小细节,需要解释一下。

  1. gt的维度大小不等于输入图像,而是等于输入图像除以下采样因子,也就是说,如果输入图像维度为640*480,下采样因子为4,中心点坐标为(123,123),那么gt的维度即为160*120,中心点坐标为(30,30),补偿值为(123-(30*4))/4=0.75。

  1. 我们应该如何设计中心点的heatmap?我们能想到的最平凡的方法就是生成一个二维矩阵,除了关键点坐标位置上的数值为1,其他均为0。但是这会出现一个问题,就是某些关键点附近的点,只要不是关键点本身,值都为0,那么在算loss的时候模型都会对他进行惩罚。这不是我们想看到的,因为即使他并不是恰好是关键点坐标,但只要离关键点足够近,依然是可以接受的。

因此我们想到了第二种方法,关键点附近的某个圆内,值均为1,其他均为0。这样做的目的是,我们将与关键点周围某个“可以接受的范围内”的点,均看作关键点的,在算loss的时候不进行惩罚。但是这样做依然会出现问题,比如距离关键点1个像素位的点,与距离关键点10个像素位的点,有可能数值均为1,那么均不会被惩罚,但是我们希望模型的输出尽可能地接近关键点。

最终我们采取的方法是在关键点周围生成高斯圆。也就是生成一个中心值为1,随着距离中心越来越远,数值越来越小的圆。这个圆的数值服从二维离散高斯分布。但是这里面牵扯到一个参数的选取,那就是方差/标准差,也可以认为是这个圆的半径大小,因为两者是等价的。具体公式可以参考[7]。原理简单来说就是遵从“检测框越大,圆半径越大”原则进行设计。

具体方法可以描述为,通过不断平移gt检测框,计算IoU,寻找IoU大于某个阈值的边界情况下的检测框顶点值,这些边界情况的顶点值距离原检测框对应顶点值的距离,就可以认为是该圆半径,从而计算出高斯圆所需要的方差值。

  1. 对于w、h的heatmap,我们并不会去做高斯圆的操作,我们只需要在关键点位置赋予对应的检测框长宽即可,但是在预测的时候,模型依然会预测其他位置的长宽,只不过我们在训练算loss的时候,其他位置的长宽并不会参与计算,因此并不影响。

另外,我们在对中心点坐标及补偿进行gt编码时,是除以了下采样因子,也就是这些值都是在“输出空间”(e.g. 160*120)里的,但对w、h进行编码时,不需要除以下采样因子,只需要将“输入空间”(e.g. 640*480)中的数值填入相应位置即可。

六、损失函数

这个模型的损失函数也比较简单,检测框长宽、坐标补偿的损失函数均为mean absolute error(MAE)。至于关键点的损失函数,采用的是最近几年在目标检测中非常受欢迎的Focal Loss。这个损失函数最初是RetinaNet[8]的作者提出的,动机是为了解决目标检测中经常出现的正负样本不平衡现象。核心思想是通过加了一个与预测值相关的权重系数,使得模型对误分类“严重”的样本加大惩罚。这里不对Focal Loss进行细致讲解,有兴趣的可以去看一下[9]。

CenterNet采用了Focal Loss,但进行了一个小优化。下图为关键点损失函数,其中带^这个符号的均为预测值,不带的为真实值,为超参数。xyc这个下标的意思是在第c类的heatmap上坐标(x,y)这个点的值。下面我们来分类讨论一下。

关键点损失函数,图源[1]

第一种情况,当时,就是标准的正样本的Focal Loss

第二种情况,当时,是在标准的负样本的Focal Loss基础上,加了一个额外的系数。之所以有这个系数,和我们在编码gt时生成的高斯圆有关。我们之前说过,当某点距离中心点很近时,我们赋予了这个点一个接近1的值,随着远离中心点,值也逐渐归为0。因此当接近1时,证明这个点距离中心点很近,可以被接受,因此将会很小,降低对这个点的惩罚力度。反之则加大惩罚力度。

这是目前关键点检测模型的损失函数的通用模版,因为关键点检测基本都采用了heatmap+高斯圆的检测方式。

七、结语

最近公司做了一个目标检测相关的项目,最终选用模型主体为CenterNet,因此读了很多CenterNet相关的文章,本文为这段时间学习的一个自我总结。 关于backbone模型、结果分析、应用实战,都没有涉及。主要考虑到篇幅及时间问题,有兴趣的读者可以阅读[1]自行探索。

不太规范的引用:

[1] https://arxiv.org/abs/1904.07850

[2] https://arxiv.org/abs/1311.2524

[3] https://arxiv.org/abs/1506.02640

[4] https://arxiv.org/abs/1612.08242

[5] https://ieeexplore.ieee.org/document/9233610

[6] https://arxiv.org/abs/1808.01244

[7] https://learnopencv.com/centernet-anchor-free-object-detection-explained/

[8] https://arxiv.org/abs/1708.02002

[9] https://medium.com/visionwizard/understanding-focal-loss-a-quick-read-b914422913e7

  • 4
    点赞
  • 8
    收藏
    觉得还不错? 一键收藏
  • 2
    评论
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值