Focal Loss for Dense Object Detection
论文地址
官方源码 caffe2
ICCV2017 Focal Loss 现场演讲: https://zhuanlan.zhihu.com/p/55869356
Focal Loss
Focal Loss 主要解决类别分类不平衡的问题。作者发现 one-stage 的检测器与 two-stage 的检测器相比,SSD 这种 one-stage 的检测器的类别分布不平衡问题更加严重 (YOLO 中有特殊的采样策略,类别分布不平衡问题不是非常严重,使用 Focal Loss 基本没有提升);two-stage 的 RCNN 系列可以在 Regin Proposal 阶段使用 Selective Search、EdgeBoxes、DeepMask、RPN 等候选区域提取方法过滤掉大部分的背景区域,在第二阶段分类时也可以使用启发式采样,例如固定正负样本比例为 1:3,online hard example mining (OHEM) 等保持样本的类别平衡。
one-shot 型的检测方法需要处理更多的候选区域,大概有100K,如果使用与 two-stage 相似的采样策略,耗费的时间过长,效率很低。
Focal Loss 可以动态缩减交叉熵损失,Focal Loss 的缩放因子可以自动降低简单样本的损失,帮助模型集中于训练更加困难的样本。Focal Loss 的思想与 OHEM 的思想有点类似,OHEM 是仅将损失较大的部分反向传播,直接忽略简单样本的损失,这种直接忽略肯定也会带来一定的影响,所以 Focal Loss 将简单样本的损失降低,而不是直接忽略,可以得到更好的结果。
Focal Loss 是直接在交叉熵损失的基础上改进的,增加了一个动态缩放因子,以二分类使用的二值交叉熵损失 (BCELoss) 举例:
C
E
(
p
,
y
)
=
{
−
l
o
g
(
p
)
i
f
y
=
1
−
l
o
g
(
1
−
p
)
o
t
h
e
r
w
i
s
e
CE(p,y)=\left\{\begin{matrix} -log(p) & if\ y=1\\ -log(1-p) & otherwise \end{matrix}\right.
CE(p,y)={−log(p)−log(1−p)if y=1otherwise
其中,
y
∈
{
±
1
}
y\in\{\pm 1\}
y∈{±1} 表示类别标签,
p
∈
[
0
,
1
]
p\in[0,1]
p∈[0,1],表示模型输出的类别为
1
1
1的概率,为了简便,定义
p
t
=
{
p
i
f
y
=
1
1
−
p
o
t
h
e
r
w
i
s
e
p_t=\left\{\begin{matrix} p & if \ y=1\\ 1-p & otherwise \end{matrix}\right.
pt={p1−pif y=1otherwise
此时,BCELoss就变成了
C
E
(
p
,
y
)
=
C
E
(
p
t
)
=
−
l
o
g
(
p
t
)
CE(p,y)=CE(p_t)=-log(p_t)
CE(p,y)=CE(pt)=−log(pt)
Focal Loss 在交叉熵损失上增加了动态放缩因子
(
1
−
p
t
)
γ
(1-p_t)^\gamma
(1−pt)γ,
γ
\gamma
γ 是一个可调的超参数,可以控制放缩比例,文中实验表明
γ
=
2
\gamma=2
γ=2时的效果最好,
F
L
(
p
t
)
=
−
(
1
−
p
t
)
γ
l
o
g
(
p
t
)
FL(p_t)=-(1-p_t)^\gamma log(p_t)
FL(pt)=−(1−pt)γlog(pt)
另外,在实践中,作者还增加了一个 Focal Loss 的平衡变量
α
\alpha
α,可以提升少量精度,文章中推荐
α
=
0.25
\alpha =0.25
α=0.25时最佳:
F
L
(
p
t
)
=
−
α
t
(
1
−
p
t
)
γ
l
o
g
(
p
t
)
FL(p_t)=-\alpha _t(1-p_t)^\gamma log(p_t)
FL(pt)=−αt(1−pt)γlog(pt)
Focal Loss 要与 Sigmoid 配合使用可以获得更好的数值稳定
![在这里插入图片描述](https://i-blog.csdnimg.cn/blog_migrate/e96dd9ad118d28694cc03646c3833ffd.png)
论文中这张图表明了不同的 γ \gamma γ的数值下概率与损失函数之间的关系,可以看到网络预测结果概率大的部分的损失小,但是在检测时会产生大量的背景区域,属于简单样本,会使总损失中简单样本的占比过大,而 Focal Loss 可以进一步减小简单样本的损失,但是不至于让简单样本的损失归于零,可以凸显出 hard example 的损失。
Focal Loss 反向求导
标准 Focal Loss 形式:
F
L
(
p
t
)
=
−
α
(
1
−
p
t
)
γ
l
o
g
(
p
t
)
FL(p_t)=-\alpha (1-p_t)^\gamma log(p_t)
FL(pt)=−α(1−pt)γlog(pt)
其中
p
t
p_t
pt
p
t
=
{
p
i
f
y
=
1
1
−
p
o
t
h
e
r
w
i
s
e
p_t=\left\{\begin{matrix} p & if \ y=1\\ 1-p & otherwise \end{matrix}\right.
pt={p1−pif y=1otherwise
softmax 公式为
p
i
=
e
x
i
∑
e
k
p_i=\frac{e^{x_i}}{\sum e^k}
pi=∑ekexi
Focal Loss 求导:
d
F
L
x
i
=
d
F
L
d
p
i
⋅
d
p
i
d
x
i
\frac{dFL}{x_i}=\frac{dFL}{dp_i}\cdot \frac{dp_i}{dx_i}
xidFL=dpidFL⋅dxidpi
其中,
d
F
L
d
p
t
=
−
α
(
d
(
1
−
p
t
)
γ
d
p
t
⋅
l
o
g
(
p
t
)
+
(
1
−
p
t
)
γ
⋅
d
l
o
g
(
p
t
)
d
p
t
)
=
−
α
(
−
γ
(
1
−
p
t
)
γ
−
1
l
o
g
(
p
t
)
+
(
1
−
p
t
)
γ
1
p
t
)
=
−
α
(
−
γ
(
1
−
p
i
)
γ
−
1
l
o
g
(
p
i
)
+
(
1
−
p
i
)
γ
1
p
i
)
\frac{dFL}{dp_t}=-\alpha(\frac{d(1-p_t)^\gamma}{dp_t}\cdot log(p_t)+(1-p_t)^\gamma \cdot \frac{dlog(p_t)}{dp_t}) \\ \quad =-\alpha (-\gamma (1-p_t)^{\gamma -1} log(p_t)+(1-p_t)^\gamma \frac{1}{p_t})\\ \quad =-\alpha (-\gamma (1-p_i)^{\gamma -1} log(p_i)+(1-p_i)^\gamma \frac{1}{p_i})
dptdFL=−α(dptd(1−pt)γ⋅log(pt)+(1−pt)γ⋅dptdlog(pt))=−α(−γ(1−pt)γ−1log(pt)+(1−pt)γpt1)=−α(−γ(1−pi)γ−1log(pi)+(1−pi)γpi1)
x
i
x_i
xi对
s
o
f
t
m
a
x
softmax
softmax求导,分为两种情况:
i
f
i
=
=
j
:
if \quad i==j:
ifi==j:
d
p
i
d
x
i
=
e
x
i
⋅
∑
e
x
k
−
e
x
i
⋅
e
x
i
∑
e
x
k
2
=
e
x
i
∑
e
x
k
−
e
x
i
∑
e
x
k
⋅
e
x
i
∑
e
x
k
=
p
i
−
p
i
⋅
p
i
=
p
i
(
1
−
p
i
)
\frac{dp_i}{dx_i} =\frac{e^{x_i}\cdot \sum e^{x_k}-e^{x_i}\cdot e^{x_i}}{\sum e^{x_k^2}}\\ \qquad \quad =\frac{e^{x_i}}{\sum e^{x_k}}-\frac{e^{x_i}}{\sum e^{x_k}}\cdot \frac{e^{x_i}}{\sum e^{x_k}}\\ \qquad =p_i-p_i \cdot p_i =p_i (1-p_i)
dxidpi=∑exk2exi⋅∑exk−exi⋅exi=∑exkexi−∑exkexi⋅∑exkexi=pi−pi⋅pi=pi(1−pi)
i
f
i
!
=
j
:
if \quad i !=j:
ifi!=j:
d
p
i
d
x
i
=
0
−
e
x
i
⋅
e
x
j
∑
e
x
k
2
=
−
p
i
⋅
p
j
\frac{dp_i}{dx_i}=\frac{0-e^{x_i}\cdot e^{x_j}}{\sum e^{x_k^2}}=-p_i\cdot p_j
dxidpi=∑exk20−exi⋅exj=−pi⋅pj
所以,
i
f
i
=
=
j
:
if \quad i==j:
ifi==j:
d
F
L
d
x
i
=
α
(
−
γ
(
1
−
p
i
)
γ
−
1
l
o
g
(
p
i
)
p
i
+
(
1
−
p
i
)
γ
)
⋅
(
p
i
−
1
)
\frac{dFL}{dx_i}=\alpha(-\gamma (1-p_i)^{\gamma -1}log(p_i)p_i + (1-p_i)^\gamma)\cdot (p_i -1)
dxidFL=α(−γ(1−pi)γ−1log(pi)pi+(1−pi)γ)⋅(pi−1)
i
f
i
!
=
j
:
if \quad i!=j:
ifi!=j:
d
F
L
d
x
i
=
α
(
−
γ
(
1
−
p
i
)
γ
−
1
l
o
g
(
p
i
)
p
i
+
(
1
−
p
i
)
γ
)
⋅
p
j
\frac{dFL}{dx_i}=\alpha(-\gamma (1-p_i)^{\gamma -1}log(p_i)p_i + (1-p_i)^\gamma)\cdot p_j
dxidFL=α(−γ(1−pi)γ−1log(pi)pi+(1−pi)γ)⋅pj
缺点
- 增加了两个超参数 ( α = 0.25 , γ = 2 ) (\alpha = 0.25, \gamma = 2) (α=0.25,γ=2),想要得到好的效果,需要精细调整
pytoch 实现
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
class FocalLoss(nn.Module):
r"""
This criterion is a implemenation of Focal Loss, which is proposed in
Focal Loss for Dense Object Detection.
Loss(x, class) = - \alpha (1-softmax(x)[class])^gamma \log(softmax(x)[class])
The losses are averaged across observations for each minibatch.
Args:
alpha(1D Tensor, Variable) : the scalar factor for this criterion
gamma(float, double) : gamma > 0; reduces the relative loss for well-classified examples (p > .5),
putting more focus on hard, misclassified examples
size_average(bool): By default, the losses are averaged over observations for each minibatch.
However, if the field size_average is set to False, the losses are
instead summed for each minibatch.
"""
def __init__(self, class_num, alpha=None, gamma=2, size_average=True):
super(FocalLoss, self).__init__()
if alpha is None:
self.alpha = Variable(torch.ones(class_num, 1))
else:
if isinstance(alpha, Variable):
self.alpha = torch.ones(class_num, 1)*alpha
else:
self.alpha = Variable(torch.ones(class_num, 1)*alpha)
self.gamma = gamma
self.class_num = class_num
self.size_average = size_average
def forward(self, inputs, targets):
N = inputs.size(0)
C = inputs.size(1)
P = F.softmax(inputs)
class_mask = inputs.data.new(N, C).fill_(0)
class_mask = Variable(class_mask)
ids = targets.view(-1, 1)
class_mask.scatter_(1, ids.data, 1.)
if inputs.is_cuda and not self.alpha.is_cuda:
self.alpha = self.alpha.cuda()
alpha = self.alpha[ids.data.view(-1)]
probs = (P*class_mask).sum(1).view(-1,1)
log_p = probs.log()
batch_loss = -alpha*(torch.pow((1-probs), self.gamma))*log_p
if self.size_average:
loss = batch_loss.mean()
else:
loss = batch_loss.sum()
return loss