神经网络训练Trick-Focal Loss

Focal Loss由(Kaiming He at., 2017)提出用于解决One-stage中正负样本不平衡的问题,同时使得网络更能挖掘困难样本的知识。
建议在看之前先看一下交叉熵的介绍:交叉熵损失函数原理详解(这篇文章对交叉熵介绍很透彻)
正负样本:在进行物体检测时,图像中的背景为负样本,物体为正样本。负样本数据大于正样本数据。
简单困难样本:出现频率高样本简单样本,出现频率低的样本为困难样本
原始交叉熵函数:
在这里插入图片描述
定义:
在这里插入图片描述
则:
在这里插入图片描述
解决正负样本不平衡问题:
给正负样本加上权重,负样本出现的频次多,那么就降低负样本的权重,正样本数量少,就相对提高正样本的权重。
在这里插入图片描述
其中: α ∈ [ 0 , 1 ] \alpha\in[0,1] α[0,1]正类, 1 − α 1-\alpha 1α负类, α \alpha α一般取0.25。
Focal Loss:
虽然(3)可以控制正负样本的权重,但是没法控制简单样本和困难样本的权重。因此进行如下改进:
在这里插入图片描述
γ \gamma γ:focusing parameter, γ ∈ [ 0 , 5 ] \gamma\in[0,5] γ[0,5] γ \gamma γ一般取2
( 1 − p t ) γ (1-p_t)^\gamma (1pt)γ:调制系数(modulating factor)。
随着对困难样本挖掘,困难样本预测概率也变大,网络对其的关注度降低。
两个重要性质:
1.当一个样本被分错pt很小,调制因子(1-pt)接近1,损失不被影响;当pt→1预测概率很好,因子(1-pt)接近0,那么分的比较好的样本的权值就被调低。因此调制系数趋于1,也就是说相比原来的loss是没有什么大的改变的。当pt趋于1的时候(此时分类正确而且是易分类样本),调制系数趋于0,也就是对于总的loss的贡献很小。
2、当γ=0的时候,focal loss就是传统的交叉熵损失,当γ增加的时候,调制系数也会增加。 专注参数γ平滑地调节了易分样本调低权值的比例。γ增大能增强调制因子的影响,实验发现γ取2最好。直觉上来说,调制因子减少了易分样本的损失贡献,拓宽了样例接收到低损失的范围。当γ一定的时候,比如等于2,一样easy example(pt=0.9)的loss要比标准的交叉熵loss小100+倍,当pt=0.968时,要小1000+倍,但是对于hard example(pt < 0.5),loss最多小了4倍。这样的话hard example的权重相对就提升了很多。这样就增加了那些误分类的重要性。
一般使用时将两者融合,既能调整正负样本的权重,又能控制难易分类样本的权重。
在这里插入图片描述
一般当γ增加的时候,a需要减小(实验中γ=2,a=0.25的效果最好)
语音识别中,例子:

class FocalLoss(nn.Module):
    def __init__(self, ignore_idx, alpha=0.25, gamma=2, smoothing=0, size_average=False):
        super(FocalLoss, self).__init__()
        self.ignore_idx = ignore_idx
        self.alpha = alpha
        self.gamma = gamma
        self.size_average = size_average
        self.smoothing = smoothing
    def forward(self, inputs, targets):
        N = inputs.size(0)
        C = inputs.size(2)
        inputs = inputs.view(-1,C)
        targets = targets.view(-1, 1)
        
        log_p = F.log_softmax(inputs,dim=1)
        
        #print('probs',probs.shape,probs)
        class_mask = inputs.clone().fill_(self.smoothing/(C-1))
        class_mask.scatter_(1, targets, 1.0-self.smoothing)
        #print('class_mask',class_mask.shape,class_mask)
        ce_loss = log_p* class_mask
        probs = torch.exp(log_p)
        #0 -> 1-a
        #1 -> (2a-1) +(1-a)=a
        alpha_facter = class_mask*(2*self.alpha - 1)
        alpha_facter = alpha_facter + (1-self.alpha)
        #print('alpha_facter',alpha_facter.shape,alpha_facter)
        batch_loss = - alpha_facter * torch.pow(1-probs, self.gamma)*ce_loss
        batch_loss = batch_loss.sum(1).view(-1)
        
        non_pad_mask = (targets  != self.ignore_idx).view(-1)
        #print('non_pad_mask',non_pad_mask.shape,non_pad_mask)
        total = non_pad_mask.sum().float()
        preds = probs.max(1)[1]
        n_correct = preds.eq(targets.view(-1))
        #print('n_correct1:',n_correct)
        n_correct = n_correct.masked_select(non_pad_mask).sum().float()
        
        #print('batch_loss',batch_loss.shape,batch_loss)

        loss = batch_loss.masked_select(non_pad_mask)
        #print('loss',loss.shape,loss)
        return loss.sum()/N, n_correct/total*100

参考:
[1] Focal Loss for Dense Object Detection [论文]
[2] Focal loss论文详解 [CSDN]

  • 1
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值