Focal Loss由(Kaiming He at., 2017)提出用于解决One-stage中正负样本不平衡的问题,同时使得网络更能挖掘困难样本的知识。
建议在看之前先看一下交叉熵的介绍:交叉熵损失函数原理详解(这篇文章对交叉熵介绍很透彻)
正负样本:在进行物体检测时,图像中的背景为负样本,物体为正样本。负样本数据大于正样本数据。
简单困难样本:出现频率高样本简单样本,出现频率低的样本为困难样本
原始交叉熵函数:
定义:
则:
解决正负样本不平衡问题:
给正负样本加上权重,负样本出现的频次多,那么就降低负样本的权重,正样本数量少,就相对提高正样本的权重。
其中:
α
∈
[
0
,
1
]
\alpha\in[0,1]
α∈[0,1]正类,
1
−
α
1-\alpha
1−α负类,
α
\alpha
α一般取0.25。
Focal Loss:
虽然(3)可以控制正负样本的权重,但是没法控制简单样本和困难样本的权重。因此进行如下改进:
γ
\gamma
γ:focusing parameter,
γ
∈
[
0
,
5
]
\gamma\in[0,5]
γ∈[0,5],
γ
\gamma
γ一般取2
(
1
−
p
t
)
γ
(1-p_t)^\gamma
(1−pt)γ:调制系数(modulating factor)。
随着对困难样本挖掘,困难样本预测概率也变大,网络对其的关注度降低。
两个重要性质:
1.当一个样本被分错,pt很小,调制因子(1-pt)接近1,损失不被影响;当pt→1及预测概率很好,因子(1-pt)接近0,那么分的比较好的样本的权值就被调低。因此调制系数趋于1,也就是说相比原来的loss是没有什么大的改变的。当pt趋于1的时候(此时分类正确而且是易分类样本),调制系数趋于0,也就是对于总的loss的贡献很小。
2、当γ=0的时候,focal loss就是传统的交叉熵损失,当γ增加的时候,调制系数也会增加。 专注参数γ平滑地调节了易分样本调低权值的比例。γ增大能增强调制因子的影响,实验发现γ取2最好。直觉上来说,调制因子减少了易分样本的损失贡献,拓宽了样例接收到低损失的范围。当γ一定的时候,比如等于2,一样easy example(pt=0.9)的loss要比标准的交叉熵loss小100+倍,当pt=0.968时,要小1000+倍,但是对于hard example(pt < 0.5),loss最多小了4倍。这样的话hard example的权重相对就提升了很多。这样就增加了那些误分类的重要性。
一般使用时将两者融合,既能调整正负样本的权重,又能控制难易分类样本的权重。
一般当γ增加的时候,a需要减小(实验中γ=2,a=0.25的效果最好)
语音识别中,例子:
class FocalLoss(nn.Module):
def __init__(self, ignore_idx, alpha=0.25, gamma=2, smoothing=0, size_average=False):
super(FocalLoss, self).__init__()
self.ignore_idx = ignore_idx
self.alpha = alpha
self.gamma = gamma
self.size_average = size_average
self.smoothing = smoothing
def forward(self, inputs, targets):
N = inputs.size(0)
C = inputs.size(2)
inputs = inputs.view(-1,C)
targets = targets.view(-1, 1)
log_p = F.log_softmax(inputs,dim=1)
#print('probs',probs.shape,probs)
class_mask = inputs.clone().fill_(self.smoothing/(C-1))
class_mask.scatter_(1, targets, 1.0-self.smoothing)
#print('class_mask',class_mask.shape,class_mask)
ce_loss = log_p* class_mask
probs = torch.exp(log_p)
#0 -> 1-a
#1 -> (2a-1) +(1-a)=a
alpha_facter = class_mask*(2*self.alpha - 1)
alpha_facter = alpha_facter + (1-self.alpha)
#print('alpha_facter',alpha_facter.shape,alpha_facter)
batch_loss = - alpha_facter * torch.pow(1-probs, self.gamma)*ce_loss
batch_loss = batch_loss.sum(1).view(-1)
non_pad_mask = (targets != self.ignore_idx).view(-1)
#print('non_pad_mask',non_pad_mask.shape,non_pad_mask)
total = non_pad_mask.sum().float()
preds = probs.max(1)[1]
n_correct = preds.eq(targets.view(-1))
#print('n_correct1:',n_correct)
n_correct = n_correct.masked_select(non_pad_mask).sum().float()
#print('batch_loss',batch_loss.shape,batch_loss)
loss = batch_loss.masked_select(non_pad_mask)
#print('loss',loss.shape,loss)
return loss.sum()/N, n_correct/total*100
参考:
[1] Focal Loss for Dense Object Detection [论文]
[2] Focal loss论文详解 [CSDN]