Focal Loss

Focal Loss for Dense Object Detection

  通过对不同样本的loss进行加权,从而达到聚焦于学习困难样本的方法,该方法普适性很强。

Key words : Sample balance、Hard example、Focusing parameter

Subjects: Computer Vision and Pattern Recognition (cs.CV)

ICCV2017

作者:RBG和Kaiming

Agile Pioneer  

交叉熵的计算形式如下:

FocalLoss定义如下:

  One-stage的目标检测算中存在正负样本不均衡的情况,以及困难样本难以分类的情况。

F L ( p t ) = − α t ( 1 − p t ) γ l o g ( p t ) FL(p_t) = -\alpha_t(1-p_t)^{\gamma}log(p_t) FL(pt)=αt(1pt)γlog(pt)

核心思想


1. 解决正负样本不均衡问题 α ∈ ( 0 , 1 ) \alpha \in (0, 1) α(0,1)

  • 1.1 α \alpha α 是调节由正负样本产生loss的参数,比如一个batch的8个样本中有7个是负样本,那么负样本loss的占比就很高,可以通过这个参数来对负样本产生的loss进行制约,相对提高正样本的loss

  • 1.2 所以在使用的时候 α \alpha α 向量是一个长度和batch_size相同的向量,然后根据类别索引来确定正负样本,来生成对应的loss权重向量[ α \alpha α, …, 1 − α 1 - \alpha 1α,…]

  • 1.3 多类别的时候怎么办呢,可以考虑不同样本的占比作为权重

  • 1.4 对于训练集合中样本比例较大的类别乘以 α \alpha α(<0.5),样本比例较小的类别1- α \alpha α,是正常的逻辑,但是在Focal loss 中是正样本乘以 α \alpha α(<0.5),本来正样本难以“匹敌”负样本,但经过下面介绍的 γ \gamma γ 的“操控”后,也许形势还逆转了,还要对正样本降权,已达到更好的效果


2. 解决困难样本学习问题 γ ∈ [ 0 , + i n f ) \gamma \in [0, +inf ) γ[0,+inf)

  • 2.1 γ \gamma γ 称为 “focusing parameter”,目的是通过减少易分类样本的权重(而不是增加困难样本的权重,这样hard example的权重相对就提升了很多),从而使得模型在训练时更专注于难分类的样本

  • 2.2 当一个样本被分错的时候,p是很小的,那么 γ = ( 1 − P ) \gamma=(1-P) γ=(1P)接近1,损失不被影响

  • 2.3 当P→1,因子 γ = ( 1 − P ) \gamma=(1-P) γ=(1P)接近0,那么分的比较好的(well-classified)样本的权值就被调低了。

  • 2.4 当 γ = 0 \gamma=0 γ=0的时候,focal loss就是传统的交叉熵损失,当 γ \gamma γ增加的时候,调制系数也会增加。 专注参数 γ \gamma γ平滑地调节了易分样本调低权值的比例。 γ \gamma γ增大能增强调制因子的影响,实验发现 γ \gamma γ取2最好。

  • 2.5 这里的 p t p_t pt是预测的onehot结果中对应真实类别的概率。


  直觉上来说,调制因子减少了易分样本的损失贡献,拓宽了样例接收到低损失的范围。当 γ \gamma γ一定的时候,比如等于2,一样easy example(pt=0.9)的loss要比标准的交叉熵loss小100+倍,当pt=0.968时,要小1000+倍,但是对于hard example(pt < 0.5),loss最多小了4倍。这样的话hard example的权重相对就提升了很多。这样就增加了那些误分类的重要性。

作者建议的最佳值:alpha取值为0.25, gamma=2

Focal loss

p t p_t pt - positive prob

F o c a l _ l o s s = { − α ( 1 − p t ) γ l o g ( p t ) y = 1 ( 1 − α ) ( p t ) γ l o g ( 1 − p t ) y = 0 Focal\_loss=\begin{cases} -\alpha (1 - p_t)^{\gamma}log(p_t) & y = 1 \\ (1-\alpha) (p_t)^{\gamma}log(1 - p_t) & y = 0 \\ \end{cases} Focal_loss={α(1pt)γlog(pt)(1α)(pt)γlog(1pt)y=1y=0

问下自己这些问题:

Q:Focal loss是如何计算的,例如p=0.9或p=0.5?

Q:Focal loss如何应用在多分类中, α \alpha α γ \gamma γ 如何取值?

Q:Focal loss中 1 − p t 1-p_t 1pt中的 p t p_t pt对应哪个类别的概率?

这是我基于pytorch写的二分类的LableSmooth结合FocalLoss的代码

import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

class FocalLossWithLabelSmooth(nn.Module):
    def __init__(self,class_num, alpha=0.25, gamma=2, eps=0.06):
        super(FocalLossWithLabelSmooth, self).__init__()
        self.class_num = class_num
        # for label smooth
        self.eps = eps
        # for focal loss
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, pred, label):
        label = label.contiguous().view(-1)
        one_hot_label = torch.zeros_like(pred)
        one_hot_label = one_hot_label.scatter(1, label.view(-1, 1), 1)
        one_hot_label = one_hot_label * (1 - self.eps) + (1 - one_hot_label) * self.eps / (self.class_num - 1)
        # for label smooth
        log_prob = F.log_softmax(pred, dim=1)
        CEloss = (one_hot_label * log_prob).sum(dim=1)
        #print(one_hot_label) 
        # for focal loss
        P = F.softmax(pred, 1)
        class_mask = pred.data.new(pred.size(0), pred.size(1)).fill_(0)
        class_mask = Variable(class_mask)
        ids = label.view(-1, 1)
        class_mask.scatter_(1, ids.data, 1.)
        probs = (P * class_mask).sum(1).view(-1, 1)
        #print(probs) 
        # if multi-class you need to modify here 
        alpha = torch.empty(label.size()).fill_(1 - self.alpha)
        # TODO: multi class
        alpha[label == 1] = self.alpha                                                                                                                     
        
        if pred.is_cuda and not alpha.is_cuda:                                                                                                             
            alpha = alpha.cuda()                                                                                                                           
        batch_loss = -alpha * (torch.pow((1 - probs), self.gamma)) * CEloss                                                                                
        loss = batch_loss.mean()                                                                                                                           
        return loss   

参考:
[1] https://zhuanlan.zhihu.com/p/49981234
[2] https://blog.csdn.net/LeeWanzhi/article/details/80069592
[3] https://blog.csdn.net/weixin_44638957/article/details/100733971

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值