【论文学习2】Focal Loss: Focal Loss for Dense Object Detection

日期:2023/12/16
论文:Focal loss for dense object detection
链接:Focal loss for dense object detection
会议:ICCV2017
参考:
[1] focal loss 通俗讲解
[2] CELoss和NLLLoss
[3] softmax 和 log_softmax

因为近期做长尾分类,所以不看目标检测部分,大部分也是参考内容中知乎里的部分

1 Binary Cross-entropy Loss

L ( y , p ^ ) = − y l o g ( p ^ ) − ( 1 − y ) l o g ( 1 − p ^ ) L(y,\hat{p}) = -ylog(\hat p)-(1-y)log(1-\hat p) L(y,p^)=ylog(p^)(1y)log(1p^)
其中y为二分类任务中的0和1,表示background、foreground; p ^ \hat p p^为预测值与GT(既y)的接近程度,越大说明越接近GT。对于一个样本,上式也可以写为:
L ( y , p ^ ) = { − l o g ( p ^ ) ,   y = 1 − l o g ( 1 − p ^ ) ,   y = 0 L(y,\hat p) = \left\{\begin{matrix} -log(\hat p) ,&& \space y=1 \\ -log(1-\hat p),&& \space y=0 \end{matrix}\right. L(y,p^)={log(p^),log(1p^), y=1 y=0

那么对于整个训练集,总损失则为下式。其中m和n分别表示正负样本个数,如果 m ≪ n m\ll n mn,在计算损失的时候负样本就会占主导,即使负样本的损失很小,但是数量很多的话,损失还是会向负样本倾斜,那么模型训练自然也会向负样本倾斜。
L C E = 1 N ( ∑ y i = 1 m l o g ( p ^ ) + ∑ y i = 0 n l o g ( 1 − p ^ ) ) L_{CE} = \frac{1}{N}(\sum_{y_i=1}^mlog(\hat p)+\sum_{y_i=0}^nlog(1-\hat p)) LCE=N1(yi=1mlog(p^)+yi=0nlog(1p^))
以下为原文中的一个comment:
在这里插入图片描述

2 Balanced CE Loss

解决这一问题的直觉方法就是给正负样本损失加个权重,既
L C E = 1 N [ ∑ y i = 1 m α l o g ( p ^ ) + ∑ y i = 0 n ( 1 − α ) l o g ( 1 − p ^ ) ] L_{CE} = \frac{1}{N}[\sum_{y_i=1}^m\alpha log(\hat p)+\sum_{y_i=0}^n(1-\alpha)log(1-\hat p)] LCE=N1[yi=1mαlog(p^)+yi=0n(1α)log(1p^)]
其中 α \alpha α是一个超参数,如果按照正负样本的频率进行取值的话,那就是 α 1 − α = n m \frac{\alpha}{1-\alpha}=\frac{n}{m} 1αα=mn。可以看到, BCE Loss其实是按照正负样本的数量在进行对数似然权重的调整。

3 Focal Loss

在这里插入图片描述

BCE Loss通过类频率去改变了对数似然的权重,但它并不区分样本的难易程度。Focal loss也是解决类别不平衡的一种loss,它和BCE loss的角度不一样,Focal loss的思想就是从样本预测的难易程度下手。

其原始形式是:
L f l = { − ( 1 − p ^ ) γ l o g ( p ^ ) ,   y = 1 − p ^ γ l o g ( 1 − p ^ ) ,   y = 0 L_{fl} = \left\{\begin{matrix} -(1-\hat p)^\gamma log(\hat p) ,&& \space y=1 \\ -\hat p^\gamma log(1-\hat p),&& \space y=0 \end{matrix}\right. \\ Lfl={(1p^)γlog(p^),p^γlog(1p^), y=1 y=0

p t = { p ^ , y = 1 1 − p ^ , y = 0 p_t = \left\{\begin{matrix} \hat p,&& y=1\\ 1-\hat p,&& y=0 \end{matrix}\right. \\ pt={p^,1p^,y=1y=0
则focal loss可以写成统一形式:
L f l = − ( 1 − p t ) γ l o g ( p t ) L_{fl} = -(1-p_t)^\gamma log(p_t) Lfl=(1pt)γlog(pt)
同理也可以让CE Loss写成统一形式:
L C E = − l o g ( p t ) L_{CE} = -log(p_t) LCE=log(pt)
可以发现,当 γ = 0 \gamma=0 γ=0,Focal loss就是CE loss,当 γ ≠ 0 \gamma \not= 0 γ=0,Focal loss比CE loss多了一个modulating factor ( 1 − p t ) γ (1-p_t)^\gamma (1pt)γ。分析一下可以发现:
1) p t → 1 p_t \to 1 pt1时, modulating factor → 0 \to0 0 L f l → 0 L_{fl}\to0 Lfl0。既如果样本很容易,那么对损失的贡献应该很小。
2) p t → 0 p_t \to 0 pt0时, modulating factor → 1 \to1 1 L f l → L C E L_{fl}\to L_{CE} LflLCE。既如果样本很难,对损失贡献很大。

所以说,当预测难度降低的时候,损失倾向于0; 当预测难度增大的时候,损失倾向于维持CE loss,这就增加了预测难度高的(hard example)样本的损失值了。其威力在论文里叙述如下所示,如果 γ = 2 , p t = 0.9 \gamma=2,p_t=0.9 γ=2,pt=0.9的话,对损失的贡献会下降100倍;
在这里插入图片描述

但是在实际中,会使用 α \alpha α平衡变体,因为实验发现相比于non- α \alpha α会提升一些精度,并且使用sigmoid函数去计算 p p p会使得数值产生稳定性

最后的focal loss形式如下👇
L f l = − α t ( 1 − p t ) γ l o g ( p t ) L_{fl} = -\alpha_t(1-p_t)^\gamma log(p_t) Lfl=αt(1pt)γlog(pt)

4 Focal loss 和 CE loss各种探讨

在这里插入图片描述

  • 这幅图的横坐标指的是预测的概率,由以上分析可知,预测概率高(容易预测)的其损失值应该小,而预测概率低的(hard example)其损失值应该跟CE loss一样。
  • 可以发现focal loss相比于CE loss,在well-classified examples上的损失值都要小,并且 γ \gamma γ的作用就是让损失值曲线更陡峭。

在这里插入图片描述

  • 左边的图主要比较的是BCE loss中 α \alpha α因子的影响,可以发现当 α = 0.75 \alpha=0.75 α=0.75时的效果最好,既正样本给75%的权重,负样本给25%的权重。但是应该跟数据集有关吧。
  • 右边的图是Focal loss的两个超参的消融实验,可以发现最佳设置为 ( γ = 2. , α = 0.25 ) (\gamma=2., \alpha=0.25) (γ=2.,α=0.25)

在这里插入图片描述

  • 知乎上还看到一个有意思的解读,就是 OHEM的AP要比OHEM 1:3的AP要好,这就说明类的个数的不平衡并不是影响实验效果的主要,预测难易程度样本的不平衡才是关键。

在这里插入图片描述

  • 这是在正负样本上对 γ \gamma γ的一些消融实验
  • 左图可以说明,大约20%的正样本占了累积loss的一半, γ \gamma γ变大,也只会使这一数值好一点点,但是影响不大,说明对于hard examples来说,它是更接近与CE Loss的,既modulating factor接近于1。
  • 右图可以说明, γ \gamma γ的增大对负样本loss的影响很大,说明让更多容易预测的负样本的损失值变小了。
    在这里插入图片描述
    在这里插入图片描述
  • 这里 x t > 0 x_t>0 xt>0就说明 p t > 0.5 p_t>0.5 pt>0.5(论文里有假设),由梯度图可以看出,对于FL及FL*(FL的变体)而言,当>0时,他们的梯度都很快趋近于0;这不同于CE Loss,CE loss在>0时还是有很大一部分梯度存在的。

5 代码

以下为3个github库的参考:

  1. 有中文解释
  2. 包含有很多loss的仓库
  3. 另一个900+⭐的仓库

代码及注释:
1) focal loss

import torch
from torch import nn
from torch.nn import functional as F

class FocalLoss(nn.Module):
    def __init__(self,alpha=None,gamma=2,num_classes=1000,size_average=True):
        '''
        Focal Loss = -\alpha*(1-y_i)**\gamma*CE_Loss
        :param alpha : 类别权重,1)为列表时,为各类别权重; 2)为常数时,类别权重为[\alpha,1-\alpha,1-\alpha...] 抑制背景类
        :param gamma : 难易样本调节因子,默认为2
        :param num_classes : 类别数
        :param size_average : 默认取平均

        :attention: 关于alpha是这样设置的:
        1) 如果传入的alpha为None,则默认alpha不存在,既为常数1
        2) 如果传入的alpha为int或float,则设置为[alpha,1-alpha,1-alpha....],默认第一类为头类(背景)
        3) 如果传入的alpha为list,判断是否等于类别大小,然后各个类对应alpha
        '''
        super().__init__()
        self.size_average = size_average
        self.alpha = alpha
        if isinstance(alpha,list):
            assert len(alpha) == num_classes
            self.alpha = torch.Tensor(alpha)
        elif isinstance(alpha,(int,float)):
            self.alpha = torch.zeros(num_classes)
            self.alpha[0] = alpha
            self.alpha[1:] = 1-alpha
        elif alpha is None:
            self.alpha = torch.ones(num_classes)
        
        self.gamma = gamma

        # print('Focal Loss:\n')
        # print('- Alpha: ',self.alpha)
        # print('\n- Gamma: ',self.gamma)        

    def forward(self,preds,labels):
        '''
        损失计算, B为batch size, C为类别数
        :param preds: 分类为[B,C]
        :param labels: 分类为[B]
        :return: loss
        '''
        if preds.dim()>2: # 一般来说分类任务是[B,C] 既每张图一个类别
            preds = preds.view(preds.size(0),preds.size(1),-1) # N,C,H,W -> N,C,H*W; 假如[128,1000,8,4] ->[128,1000,32]
            preds = preds.transpose(1,2) # N,C,H*W -> N,H*W,C ; [128,1000,32] -> [128,32,1000]
            preds = preds.contiguous().view(-1,preds.size(2)) # N*H*W,C ; [4096,1000]
        preds = preds.view(-1,preds.size(-1)) # [B,C]

        preds_logsoft = F.log_softmax(preds,dim=1) # log_softmax:先softmax后再log
        preds_softmax = torch.exp(preds_logsoft) # e*(log*softmax) = softmax 概率p


        p_t = preds_softmax.gather(1,labels.view(-1,1)) # GT标签对应的概率
        preds_logsoft = preds_logsoft.gather(1,labels.view(-1,1)) # CE Loss = softmax + log + nll loss
        
        alpha = self.alpha.to(preds.device) 
        alpha = alpha.gather(0,labels.view(-1))

        loss = -torch.mul(torch.pow((1-p_t),self.gamma),preds_logsoft)
        loss = torch.mul(alpha,loss.t())
        if self.size_average:   
            return loss.mean()
        else:
            return loss.sum()       

2) demo

from __future__ import print_function

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.autograd import Variable

import os,sys,random,time
import argparse

from Focal_loss import FocalLoss


start_time = time.time()
maxe = 0
for i in range(1000):
    x = torch.rand(12800,2)*random.randint(1,10)
    x = Variable(x.cuda())
    l = torch.rand(12800).ge(0.1).long()
    l = Variable(l.cuda())

    output0 = FocalLoss(gamma=0)(x,l)
    output1 = nn.CrossEntropyLoss()(x,l)
    a = output0.item()
    b = output1.item()
    if abs(a-b)>maxe: maxe = abs(a-b)
print('time:',time.time()-start_time,'max_error:',maxe)


start_time = time.time()
maxe = 0
for i in range(100):
    x = torch.rand(128,1000,8,4)*random.randint(1,10)
    x = Variable(x.cuda())
    l = torch.rand(128,8,4)*1000   # 1000 is classes_num
    l = l.long()
    l = Variable(l.cuda())

    output0 = FocalLoss(gamma=0)(x,l)
    output1 = nn.NLLLoss2d()(F.log_softmax(x),l)
    a = output0.item()
    b = output1.item()
    if abs(a-b)>maxe: maxe = abs(a-b)
print('time:',time.time()-start_time,'max_error:',maxe)

在这里插入图片描述

  • 18
    点赞
  • 18
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值