论文链接: https://arxiv.org/pdf/1708.02002.pdf
本篇论文是何恺明及其团队17年发表在ICCV上的paper。最初用于目标检测,因为解决了分类中类别不平衡的问题,nlp中也有应用。
下面是各部分的要点,并非全文翻译。
Abstract
目前最高精度的目标检测器是基于由R-CNN推广的 two-stage 方法,其中分类器应用于稀疏的候选对象位置集。相比之下,在可能的物体位置的规则,密集采样上应用的 one-stage 探测器具有更快和更简单的可能性,但迄今为止已经落后于 two-stage 探测器的精度。
我们发现在密集探测器训练过程中遇到的极端前景 - 背景类不平衡是其中心原因。我们提出通过重塑标准交叉熵损失来解决这种类不平衡问题,从而降低分配给分类良好的例子的损失。
我们创新的“focal loss”将训练重点放在困难样本的稀疏集上,并防止大量的简单负例“压倒(overwhelming)”探测器。
为了评估我们损失的有效性,我们设计并训练了一个简单的密集检测器,我们称之为RetinaNet。我们的研究结果表明,当使用焦点损失进行训练时,RetinaNet能够匹配先前 one-stage 探测器的速度,同时超越所有现有技术的 two-stage 探测器的精度。
1. Introduction
目标识别有两大经典结构: 第一类是以RCNN为代表的两级识别方法,更复杂的有特征金字塔网络(FPN)或 Mask R-CNN、Faster R-CNN等等。
这种结构的第一级专注于proposal的提取,迅速将候选对象位置的数量缩小到较小的数量,过滤掉大多数背景样本。第二级则对提取出的proposal进行分类和精确坐标回归,执行采样启发式算法,例如固定的前景 - 背景比(1:3)或在线难例挖掘(OHEM),以保持前景和背景之间的manageable平衡。两级结构准确度较高,但因为第二级需要单独对每个proposal进行分类/回归,速度就打了折扣。
目标识别的第二类结构是以YOLO和SSD为代表的单级结构,它们摒弃了提取proposal的过程,只用一级就完成了识别/回归,虽然速度较快但准确率远远比不上两级结构。那有没有办法在单级结构中也能实现较高的准确度呢?Focal Loss就是要解决这个问题。
在本文中,我们提出了一种新的损失函数,它可以作为处理类不平衡的先前方法的一种更有效的替代方法。损失函数是动态缩放的交叉熵损失,其中缩放因子随着正确类中的置信度增加而衰减为零。直观地说,这个缩放因子可以在训练期间自动降低简单示例的权重,并快速将模型focus在难例(hard example)上。
focal loss的确切形式并不重要,我们展示其他实例可以实现类似的结果。
(计算Loss的bbox可以分为positive和negative两类。当bbox(由anchor加上偏移量得到)与ground truth间的IOU大于上门限时(一般是0.5),会认为该bbox属于positive example,如果IOU小于下门限就认为该bbox属于negative example。在一张输入image中,目标占的比例一般都远小于背景占的比例,所以两类example中以negative为主,这引发了两个问题:
1、negative example过多造成它的loss太大,以至于把positive的loss都淹没掉了,不利于目标的收敛;
2、大多negative example不在前景和背景的过渡区域上,分类很明确(这种易分类的negative称为easy negative),训练时对应的背景类score会很大,换个角度看就是单个example的loss很小,反向计算时梯度小。梯度小造成easy negative example对参数的收敛作用很有限,我们更需要loss大的对参数收敛影响也更大的example,即hard positive/negative example。
这里要注意的是前一点我们说了negative的loss很大,是因为negative的绝对数量多,所以总loss大;后一点说easy negative的loss小,是针对单个example而言。
Faster RCNN的两级结构可以很好的规避上述两个问题。具体来说它有两大法宝:1、会根据前景score的高低过滤出最有可能是前景的example (1K~2K个),因为依据的是前景概率的高低,就能把大量背景概率高的easy negative给过滤掉,这就解决了前面的第2个问题;2、会根据IOU的大小来调整positive和negative example的比例,比如设置成1:3,这样防止了negative过多的情况(同时防止了easy negative和hard negative),就解决了前面的第1个问题。所以Faster RCNN的准确率高。
OHEM是近年兴起的另一种筛选example的方法,它通过对loss排序,选出loss最大的example来进行训练,这样就能保证训练的区域都是hard example。这个方法有个缺陷,它把所有的easy example都去除掉了,造成easy positive example无法进一步提升训练的精度。)
2. Related Work
Classic Object Detectors: 经典目标探测器,滑动窗口范例(sliding-window paradigm),如LeCun等将CNN用于手写数字识别,其他如HOG行人检测,DPM等。
Two-stage Detectors: R-CNN、Region Proposal Networks (RPN)。
One-stage Detectors: OverFeat、SSD、YOLO。但准确性下降了。
Class Imbalance: 单级检测器存在很大的类不平衡问题。这些检测器在每个图像中评估104-105 个候选位置,但只有少数位置包含对象。
这种不平衡导致两个问题:(1)训练效率低下,因为大多数位置容易产生负面影响而没有有用的学习信号; (2)整体而言,easy negatives可以overwhelm训练并导致模型退化。
常见的解决方法: hard negative mining 和 more complex sampling/reweighing schemes。
Robust Estimation: focal loss不是用于解决异常值,而是与鲁棒的损失起着相反的作用:它将训练集中在一组稀疏的难例上。
3. Focal Loss
全文重点。
首先从二元分类的交叉熵引入Focal loss。
交叉熵:
C
E
(
p
,
y
)
=
{
−
log
(
p
)
if
y
=
1
−
log
(
1
−
p
)
otherwise
\mathrm{CE}(p, y)=\left\{\begin{array}{ll}{-\log (p)} & {\text { if } y=1} \\ {-\log (1-p)} & {\text { otherwise }}\end{array}\right.
CE(p,y)={−log(p)−log(1−p) if y=1 otherwise
其中y=-1,+1,p属于[0,1]。定义
p
t
p_t
pt为:
p
t
=
{
p
if
y
=
1
1
−
p
otherwise
p_{\mathrm{t}}=\left\{\begin{array}{ll}{p} & {\text { if } y=1} \\ {1-p} & {\text { otherwise }}\end{array}\right.
pt={p1−p if y=1 otherwise
则交叉熵公式可以重写为:
C
E
(
p
,
y
)
=
C
E
(
p
t
)
=
−
log
(
p
t
)
\mathrm{CE}(p, y)=\mathrm{CE}\left(p_{\mathrm{t}}\right)=-\log \left(p_{\mathrm{t}}\right)
CE(p,y)=CE(pt)=−log(pt)
3.1. Balanced Cross Entropy
解决类不平衡的常用方法是引入加权因子
α
\alpha
α。我们定义
α
\alpha
α类似于定义
p
t
p_t
pt。则
α
\alpha
α-balanced CE损失为:
C
E
(
p
t
)
=
−
α
t
log
(
p
t
)
\mathrm{CE}\left(p_{\mathrm{t}}\right)=-\alpha_{\mathrm{t}} \log \left(p_{\mathrm{t}}\right)
CE(pt)=−αtlog(pt)
我们认为这是我们提议的焦点损失的实验基线。
(当label为1时,权重为a,label为0时,权重为1-a。这个就是简单地对正负样本的损失进行加权。)
3.2. Focal Loss Definition
虽然
α
\alpha
α平衡正/负例的重要性,它没有区分简单/困难的例子。所以我们建议reshape损失函数,以给简单样本降权,从而集中训练困难负例。
我们建议给交叉熵损失函数添加一个modulating factor调制因子
(
1
−
p
t
)
γ
\left(1-p_{\mathrm{t}}\right)^{\gamma}
(1−pt)γ。
focal loss定义为:
F
L
(
p
t
)
=
−
(
1
−
p
t
)
γ
log
(
p
t
)
\mathrm{FL}\left(p_{\mathrm{t}}\right)=-\left(1-p_{\mathrm{t}}\right)^{\gamma} \log \left(p_{\mathrm{t}}\right)
FL(pt)=−(1−pt)γlog(pt)
focal loss具有两个属性:(1)当一个例子被错误分类并且pt很小时,调制因子接近1并且损失不受影响。当pt ->1时,因子变为0,并且分类良好的示例的损失是低权重的。 (2)聚焦参数
γ
\gamma
γ平滑地调整容易样本下降的速率。当
γ
\gamma
γ= 0时,FL相当于CE,并且随着
γ
\gamma
γ增加,调制因子的影响同样增加(我们发现
γ
\gamma
γ= 2在我们的实验中效果最好)。
( γ \gamma γ代表易难样本权重差别的难度, γ \gamma γ越大,差别越大,作者给出了【0,5】的取值范围,如果 γ \gamma γ = 0,那么focal loss退化成为普通的CEloss。考虑到上式的p代表正确分类的概率,p值越大代表预测越准确,p值越小代表预测越不准确。从公式可以看出,当p值趋向于1的时候,权重几乎为0,当p趋向于0时,权重趋向于1,这样就实现了在训练过程中自动的调整难易样本的权重。)
在实践中,我们使用
α
\alpha
α-balanced的变种focal loss:
F
L
(
p
t
)
=
−
α
t
(
1
−
p
t
)
γ
log
(
p
t
)
\mathrm{FL}\left(p_{\mathrm{t}}\right)=-\alpha_{\mathrm{t}}\left(1-p_{\mathrm{t}}\right)^{\gamma} \log \left(p_{\mathrm{t}}\right)
FL(pt)=−αt(1−pt)γlog(pt)
3.3. Class Imbalance and Model Initialization
默认情况下,二值分类模型初始化为输出y = - 1或1的概率相等。在这样的初始化下,在类不平衡的情况下,由于类频率而造成的损失会主导全员损失,导致早期训练的不稳定。为了解决这一问题,在训练开始时引入了“先验”(prior)的概念,即稀有类(前景)the rare class (foreground)的模型估计的p值。我们通过π表示先验,这样模型的估计p稀有类的例子很低,例如0.01。我们注意到这是模型初始化的一个变化(参见§4.1),而不是损失函数的变化。我们发现,在严重不平衡的情况下,这种方法可以提高交叉熵和焦点损失的训练稳定性。
(在训练初始阶段因为positivie和negative的分类概率基本一致,会造成公式1起不到抑制easy example的作用,为了打破这种情况,作者对最后一级用于分类的卷积的bias(具体位置见图2)作了下小修改,把它初始化成一个特殊的值b=-log((1-π)/π)。π在论文中取0.01,这样做能在训练初始阶段提高positive的分类概率。)
3.4. Class Imbalance and Two-stage Detectors
两阶段检测器通常通常用交叉熵损失来训练不用 α \alpha α-balanced或Focal loss。相反,两阶段检测器通过两种机制来解决类不平衡问题:(1)两级联(a two-stage cascade)和(2)有偏的小批量抽样。第一个级联阶段是一个目标建议机制,它将几乎无限的可能目标位置集减少到1000或2000个。重要的是,所选择的建议不是随机的,而是可能与真实的对象位置相对应的,这消除了绝大多数容易产生的负面影响。在培训第二阶段时,偏置抽样通常用于构建包含正、负样本比例为1:3的minibatch。这个比例是一个隐式 α \alpha α-balancing因子实现通过抽样。我们提出的Focal loss是通过损失函数在一阶段检测器中直接解决这个问题。
4. RetinaNet Detector
(这部分我不是做CV的所以不太懂)
Retinanet是由一个主干网和两个特定于任务的子网组成的单一、统一的网络。主干负责计算整个输入图像上的卷积特征图,是一个自定义卷积网络。第一子网对骨干网的输出进行卷积对象分类;第二个子网络执行卷积边界盒回归。这两个子网具有一个简单的设计,我们特别针对单阶段密集检测提出了这个设计,如下图所示。虽然对于这些组件的细节有许多可能的选择,但是大多数设计参数对实验中显示的精确值并不特别敏感。
FPN骨干网: 采用特征金字塔网络(FPN)作为支持网络的骨干网络。简而言之,FPN通过自顶向下的路径和横向连接扩展了一个标准的卷积网络,使得该网络能够有效地从一个分辨率输入图像构建一个丰富的多尺度特征金字塔,如上图(a)-(b)所示。金字塔的每一层都可以用来探测不同尺度的物体。
Anchor: 使用平移不变的锚框,类似于Fast R-CNN中的RPN变体。
分类子网络: 分类子网为每个A锚和k对象类预测对象在每个空间位置出现的概率。这个子网是一个小的FCN附加到每个FPN级;此子网的参数在所有金字塔级别上共享。
回归子网络(Box Regression Subnet): 与对象分类子网并行,我们将另一个小FCN附加到每个金字塔级别,以便将每个锚框的偏移量回归到附近的ground-truth对象(如果存在的话)。
4.1. Inference and Training
Inference: RetinaNet形成一个单一的FCN,由一个ResNet-FPN主干、一个分类子网和一个box regression子网组成,如上图所示。因此,推理只涉及通过网络前向传播图像。为了提高速度,我们在阈值检测器置信度为0.05后,只对每个FPN级别的最高1k得分预测进行解码。将各个级别的最高预测进行合并,并使用阈值为0.5的NMS来产生最终的检测结果。
Focal Loss: γ = 2 , α = 0.25 \gamma=2,\alpha=0.25 γ=2,α=0.25实践效果最好, γ ∈ [ 0.5 , 5 ] \gamma \in[0.5,5] γ∈[0.5,5]RetinaNet相对鲁棒。
Initialization: 分类子网络的最后一个卷积层,把bias初始化为 b = − log ( ( 1 − π ) / π ) , π = 0.01 b=-\log ((1-\pi) / \pi),\pi=0.01 b=−log((1−π)/π),π=0.01。
Optimization: SGD优化,8GPU,每个minibatch16个图像。训练90k次迭代,初始学习率0.01,在60k和80k的时候各除以10.使用水平图像翻转作为唯一的数据增强形式。weight delay0.0001,动量0.9.训练损失是box regression中focal loss与标准光滑L1损失之和。
5. Experiments
我们在具有挑战性的COCO基准的边界框检测轨迹上提供实验结果。
Focal Loss分析
为了弄明白FL为什么更好,我们分析FL的经验分布。在大量随机图片中采用约10^7 个负样本和10^5 个正样本;随后计算FL值,并归一化。对归一化的FL值排序并分别画出正负样本的累积分布函数图。
a) 20%的困难正样本贡献了约一半的损失值;当γ增加时,20%的贡献度更加明显;但是影响不是很大。
b) γ对于负样本的影响完全不同;γ = 0时CDF与正样本类似;当γ增加时,迅速的聚焦于困难负样本,当γ=2时,大量的易分类负样本仅仅贡献了少量损失。
(作者比较了在不同的超参数v下的CDF曲线。CDF(累积分布函数,cumulative distribution function)是将所有的loss进行归一化之后按照从小到大排序,横坐标代表loss数量的比例,纵坐标代表已有的loss相加占总loss的比例,该曲线越弯曲,代表loss之间的差别越大,即说明难学的样本(loss大)在总loss中的权重更大,从而可以更好地指导梯度下降的方向。从上面两个图可以看出,增加
γ
\gamma
γ都会使得CDF曲线变得弯曲,特别是对于background,这种效果更为明显,这也验证了backgroud中存在大量容易学习的样本。)
在线困难样本挖掘(OHEM)
OHEM中所有样本经过计算损失值,然后使用NMS过滤,最后在minibatch时选择损失值最大的那些样本。OHEM关注误分类样本,不同于FL,OHEM完全忽略的易分类样本。如Table 1d所示,最好的OHEM与FL有3.2点AP差距。
6. Conclusion
类不平衡是阻止单级对象检测器超越性能最好的两级方法的主要障碍。为了解决这一问题,我们提出了focal loss,它将调制项应用于交叉熵损失,以便将学习集中在困难的负面例子上。我们的方法简单而高效。我们通过设计一个全卷积的单级检测器来证明它的有效性,并报告了大量的实验分析,表明它达到了最先进的精度和速度。
源代码地址:https://github.com/facebookresearch/Detectron
参考网址:
Focal Loss for Dense Object Detection解读
RetinaNet——《Focal Loss for Dense Object Detection》论文翻译(有部分不准确的地方,但几乎全文翻译了)
Focal Loss for Dense Object Detection(文献阅读)(部分翻译,翻译得比较精细)
Focal Loss for Dense Object Detection读书笔记(要点总结)
Focal Loss for Dense Object Detection解读(含部分背景知识)
何恺明大神的「Focal Loss」,如何更好地理解?(另一种思路)
Focal Loss理解
聊聊Focal Loss及其反向传播