1. 前言
Focal Loss
最初是由Kaiming大神在Focal Loss for Dense Object Detection
一文中提出的,旨在解决目标检测中的数据类别不平衡造成的模型性能问题,也常用于NLP领域。
本质上,Focal Loss
是解决分类问题中类别不均衡、分类难度差异的一个损失函数。
2. 细节
2.1 交叉熵损失函数
C
E
(
p
,
y
)
=
{
−
l
o
g
(
p
)
,
y
=
1
−
l
o
g
(
1
−
p
)
,
y
=
o
t
h
e
r
w
i
s
e
CE(p,y)=\left\{ \begin{matrix} -log(p), y=1 \\ -log(1-p) ,y=otherwise \end{matrix} \right.
CE(p,y)={−log(p),y=1−log(1−p),y=otherwise
令:
p
t
=
{
p
,
y
=
1
1
−
p
,
y
=
o
t
h
e
r
w
i
s
e
p_t=\left\{ \begin{matrix} p, y=1 \\ 1-p ,y=otherwise \end{matrix} \right.
pt={p,y=11−p,y=otherwise
所以:
C
E
(
p
,
y
)
=
C
E
(
p
t
)
=
−
l
o
g
(
p
t
)
CE(p,y)=CE(p_t)=-log(p_t)
CE(p,y)=CE(pt)=−log(pt)
2.2 样本不平衡
对所有样本,其损失函数为:
L
=
1
N
∑
i
=
1
N
l
(
y
i
,
p
^
i
)
L=\frac{1}{N}\sum_{i=1}^Nl(y_i,\hat p_i)
L=N1i=1∑Nl(yi,p^i)
对于二分类问题,损失函数为:
L
=
1
N
(
∑
y
i
=
1
m
−
l
o
g
(
p
^
)
+
∑
y
i
=
0
n
−
l
o
g
(
1
−
p
^
)
)
L=\frac{1}{N}(\sum_{y_i=1}^m-log(\hat p)+\sum_{y_i=0}^n-log(1-\hat p))
L=N1(yi=1∑m−log(p^)+yi=0∑n−log(1−p^))
其中m
为正样本个数,n
为负样本个数,N
为样本总数,
N
=
m
+
n
N=m+n
N=m+n,当样本分布失衡时损失函数的分布会发生倾斜(如
m
<
<
n
m<<n
m<<n时,负样本的损失就会占据损失的主要部分)。由于损失函数倾斜,模型训练过程中会倾向于样本多的类别,从而造成模型对少样本类别的性能差。
2.3 balanced cross entropy
balanced cross entropy
是平衡交叉熵函数
,该函数为交叉熵损失函数增加一个权重因子,用来调整损失函数分布。公式如下:
C
E
(
p
t
)
=
−
α
t
l
o
g
(
p
t
)
CE(p_t)=-\alpha _tlog(p_t)
CE(pt)=−αtlog(pt)
α
\alpha
α是超参数,一般类别样本数量越多
α
\alpha
α值越小。
2.4 focal loss
与balanced cross entropy
不同的是:focal loss
是从loss
的角度解决样本不均衡问题,其公式如下:
F
L
(
p
t
)
=
−
(
1
−
p
t
)
γ
l
o
g
(
p
t
)
FL(p_t)=-(1-p_t)^\gamma log(p_t)
FL(pt)=−(1−pt)γlog(pt)
其中
γ
>
0
\gamma >0
γ>0,是调整因子。当
γ
=
0
\gamma =0
γ=0时,focal loss
等价于corss entorypy
。如下图所示:
3. 特点
(
1
−
p
t
)
γ
(1-p_t)^{\gamma}
(1−pt)γ是调制因子(modulating factor
),从以上公式可得出如下推论:
- 当
p
t
p_t
pt趋于0的时候(样本分类错误,属于
难分类样本
),调制因子趋于1,该部分损失在总loss中基本不受影响。当 p t p_t pt趋于1的时候(样本分类正确,属于易分类样本
),调制因子趋于0,该部分损失在总loss中的权重变小。 - 参数 γ \gamma γ平滑的降低易分类样本损失在总损失的比例,使样本更加专注于学习难分类样本的特征。当 γ = 0 \gamma =0 γ=0的时候,focal loss就是传统的交叉熵损失,可以通过调整 γ \gamma γ实现调制因子的改变。
4. 编码
class WeightedFocalLoss(nn.Module):
"Non weighted version of Focal Loss"
def __init__(self, alpha=.25, gamma=2):
super(WeightedFocalLoss, self).__init__()
self.alpha = torch.tensor([alpha, 1-alpha]).cuda()
self.gamma = gamma
def forward(self, inputs, targets):
BCE_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none')
targets = targets.type(torch.long)
at = self.alpha.gather(0, targets.data.view(-1))
pt = torch.exp(-BCE_loss)
F_loss = at*(1-pt)**self.gamma * BCE_loss
return F_loss.mean()