Focal Loss(ICCV2017 best student paper)

67 篇文章 43 订阅
31 篇文章 6 订阅

Focal Loss由FAIR提出。Kaiming包揽了ICCV2017的最佳论文(Mask R-CNN)和最佳学生论文(Focal Loss)。


按照国际惯例,给出Focal Loss的论文标题和链接:

Focal Loss for Dense Object Detection
http://openaccess.thecvf.com/content_ICCV_2017/papers/Lin_Focal_Loss_for_ICCV_2017_paper.pdf

首先,Focal loss解决了一个什么问题呢?focal loss是一种改进型的损失函数,可以更大程度上增加检测器的性能。

前言

目前的目标检测算法(下文称“检测器”)分为“two-stage”和“one-stage”两种,前者速度慢但性能好,而后者速度快但性能不如前者。首先,这篇文章讨论了“one-stage”检测器性能不好的主要原因高密度检测器中正负样本的不均衡。原文描述如下:

We discover that the extreme foreground-background class imbalance encountered during training of dense detectors is the central cause. 

负样本往往会占据绝大部分,而大量的负样本在梯度反传的过程中会影响已经训练很好的正样本分类性能。所以,文章提出了一种处理样本不均衡的方法:改进标准交叉熵损失函数的形式。对于已经具有很好分类效果的类别,适当降低其损失权重,以“保护检测器”免受大量负样本的干扰。为了评估这个改进型损失函数是否有用,就特地提出了“one-stage”检测器RetinaNet用以测试。

由上图可见,使用了focal loss的一阶段检测器RetinaNet在COCO上测试,可以达到SSD的速度以及某些二阶段检测器的AP。


样本不均衡问题 

样本不均衡是一阶段检测器普遍面临的问题,这些检测器通常需要从一张图象中提取的10^4~10^5个候选位置中选择极少数含有目标的位置。这种样本不均衡会导致量2个问题:

1. 训练低效。大量的训练是无效训练,因为都是一些简单负样本;

2. 大量简单负样本将影响训练,导致模型退化。


Focal Loss

Focal loss是一种基于交叉熵损失函数改进的损失函数,所以先了解常规的交叉熵损失函数:

CE(p,y)=\left\{\begin{matrix}-log(p), y=1 & & & \\ -log(1-p), otherwise & & & \end{matrix}\right.             式(1)

这里的y代表是否为目标,即区分背景和前景,取值只有{1,-1}两种。这里的p代表概率性(取值范围是[0,1]),即模型对目标分类的确定性(有多大把握的意思)。我们来看一下f(x)=-log(x)的曲线图:

由图像可知,-log(x)函数是一个单调递减函数。我们优化神经网络的过程中,都希望损失函数值越小越好。于是,当y=1时(即确定它是目标时),我们就将p调大以减小损失函数值。当y=-1时(即确定为背景时),我们降低p以减小损失函数值。这就是二值交叉熵损失函数对于目标检测算法的意义

式(1)是一个分段函数,我们可以用一个非分段函数简化它,可以引进一个参数p_t

p_t = \left\{\begin{matrix}p,y=1 & & & \\ (1-p),otherwise & & & \end{matrix}\right.

这样,就可以将式(1)简化为:CE(p_t)=-log(p_t)      式(2)

有一种常用的均衡方法叫“α均衡”,就是在交叉熵损失函数前面乘一个系数,以此来调节“不均衡”,如下:

CE(p_t)=-\alpha_tlog(p_t)    式(3)

到目前为止,一点都不难理解。我们可以由此引出Focal Loss的概念,也是在基础的交叉熵损失函数前乘以一项,形式如下:

FL(p_t)=-(1-p_t)^\gamma log(p_t)     式(4)

其中γ>=0,当γ=0时,这就是一个普通的交叉熵损失函数,而在实验中γ=2可以获得最佳效果。同时,Focal Loss也可以结合“α均衡”:FL(p_t)=-\alpha _t(1-p_t)^\gamma log(p_t)    式(5)

式(4)和式(5)就是Focal loss的标准形式了。很简洁。

原文里有一段:

For instance, with γ = 2, an example classified with pt = 0.9 would have 100× lower loss compared with CE and with pt ≈ 0.968 it would have 1000× lower loss. This in turn increases the importance of correcting misclassified examples (whose loss is scaled down by at most 4× for pt ≤ .5 and γ = 2). 

当pt接近1的时候,loss值会非常小。从而不会对模型训练产生影响。采用loss值的方法来确定模型训练的“焦点”,这就是Focal Loss的创新点了。 


从求导角度看Focal Loss

Focal Loss就是在交叉熵损失函数的结构上做了一点点改进,获得了很好的效果。想看提升效果的同学可以直接戳原文看表格,我这里就不截图列出了。木盏本人的习惯就是会在paper reading的同时做一些自己的分析,在这里我从求导的角度上来分析一下FL。

1,先看(1-p_t)^\gamma(γ>0)的图像:p_t的取值范围是[0,1],分别取γ=0.2, 0.5,1,2,3,4(本人所绘制)

从上图可以看出,无论γ怎么取值,(1-p_t)^\gamma在[0,1]上都是一个递减函数。当γ<1时,就是一个凹函数,当γ>1时,就变成一个凸函数。

对式(4)进行讨论,当y=1时:(注意,我已经把pt替换成了p,取值γ=2

FL(p)=(1-p)^2 (-log(p))   式(6)

其导数为:FL' = -[2(1-p)log(p_t)+(1-p)^2\frac{1}{pln10}]   式(7)

当y=-1时:

FL(p)=p^2(-log(1-p))  式(8)

其导数为:FL' = 2plog(1-p)+p^2\frac{1}{(1-p)ln10}   式(9)

我们知道,用反向传播算法进行神经网络优化时(即减小损失函数值),采用偏导计算出梯度,然而偏导计算需要使用链式法则,则focal loss对p的偏导(如式(7)和式(9)所示)将会直接成为下级偏导的“系数”。我们直接观察式(9),这是遇到负样本的情况,当大量负样本出现时,p再遇到负样本是会变成一个很小的值,可以看到式(9)中分别有"2p"和"p^2"作为系数。这样一来,直接让遇到负样本时反向传播的梯度变得接近于0,于是大量出现的负样本就不会造成模型退化了。


Pytorch实现:(来自https://github.com/yatengLG/Focal-Loss-Pytorch

# -*- coding: utf-8 -*-
# @Author  : LG
from torch import nn
import torch
from torch.nn import functional as F

class focal_loss(nn.Module):
    def __init__(self, alpha=0.25, gamma=2, num_classes = 3, size_average=True):
        """
        focal_loss损失函数, -α(1-yi)**γ *ce_loss(xi,yi)
        步骤详细的实现了 focal_loss损失函数.
        :param alpha:   阿尔法α,类别权重.      当α是列表时,为各类别权重,当α为常数时,类别权重为[α, 1-α, 1-α, ....],常用于 目标检测算法中抑制背景类 , retainnet中设置为0.25
        :param gamma:   伽马γ,难易样本调节参数. retainnet中设置为2
        :param num_classes:     类别数量
        :param size_average:    损失计算方式,默认取均值
        """
        super(focal_loss,self).__init__()
        self.size_average = size_average
        if isinstance(alpha,list):
            assert len(alpha)==num_classes   # α可以以list方式输入,size:[num_classes] 用于对不同类别精细地赋予权重
            print(" --- Focal_loss alpha = {}, 将对每一类权重进行精细化赋值 --- ".format(alpha))
            self.alpha = torch.Tensor(alpha)
        else:
            assert alpha<1   #如果α为一个常数,则降低第一类的影响,在目标检测中为第一类
            print(" --- Focal_loss alpha = {} ,将对背景类进行衰减,请在目标检测任务中使用 --- ".format(alpha))
            self.alpha = torch.zeros(num_classes)
            self.alpha[0] += alpha
            self.alpha[1:] += (1-alpha) # α 最终为 [ α, 1-α, 1-α, 1-α, 1-α, ...] size:[num_classes]

        self.gamma = gamma

    def forward(self, preds, labels):
        """
        focal_loss损失计算
        :param preds:   预测类别. size:[B,N,C] or [B,C]    分别对应与检测与分类任务, B 批次, N检测框数, C类别数
        :param labels:  实际类别. size:[B,N] or [B]
        :return:
        """
        # assert preds.dim()==2 and labels.dim()==1
        preds = preds.view(-1,preds.size(-1))
        self.alpha = self.alpha.to(preds.device)
        preds_logsoft = F.log_softmax(preds, dim=1) # log_softmax
        preds_softmax = torch.exp(preds_logsoft)    # softmax

        preds_softmax = preds_softmax.gather(1,labels.view(-1,1))   # 这部分实现nll_loss ( crossempty = log_softmax + nll )
        preds_logsoft = preds_logsoft.gather(1,labels.view(-1,1))
        self.alpha = self.alpha.gather(0,labels.view(-1))
        loss = -torch.mul(torch.pow((1-preds_softmax), self.gamma), preds_logsoft)  # torch.pow((1-preds_softmax), self.gamma) 为focal loss中 (1-pt)**γ

        loss = torch.mul(self.alpha, loss.t())
        if self.size_average:
            loss = loss.mean()
        else:
            loss = loss.sum()
        return loss

  • 12
    点赞
  • 25
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

木盏

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值