Focal Loss详解及其pytorch实现

Focal Loss详解及其pytorch实现




引言

Focal Loss是由何恺明等人在2017年的论文《Focal Loss for Dense Object Detection》中提出的。它是一种专门为解决目标检测中类别不平衡和难易样本不平衡问题而设计的损失函数。本文将详细介绍Focal Loss的基本概念、二分类和多分类的交叉熵损失函数,以及如何设置Focal Loss中的关键参数,并提供PyTorch的实现代码。

二分类与多分类的交叉熵损失函数

二分类交叉熵损失

在二分类的任务中,一般使用Sigmoid作为最后的激活函数,输出代表样本为正的概率值 y ^ \hat{y} y^,二分类非正即负,所以样本为负的概率值为 1 − y ^ 1-\hat{y} 1y^。二分类交叉熵损失的计算公式为:

CEL = − y ⋅ log ⁡ ( y ^ ) − ( 1 − y ) ⋅ log ⁡ ( 1 − y ^ ) \text{CEL} = -y \cdot \log(\hat{y}) - (1-y) \cdot \log(1-\hat{y}) CEL=ylog(y^)(1y)log(1y^)

其中 y y y 是实际标签,正样本为1,负样本为0, y ^ \hat{y} y^ 是Sigmoid激活函数的输出值。

多分类交叉熵损失

在多分类的情况下,一般使用Softmax作为最后的激活函数,输出有多个值,对应每个分类的概率值,且这些值之和为1。多分类交叉熵损失的计算公式为:

CEL = − ∑ c = 1 C y c ⋅ log ⁡ ( y ^ c ) = − log ⁡ ( y ^ c ) \text{CEL} = -\sum_{c=1}^{C} y_c \cdot \log(\hat{y}_c) = -\log(\hat{y}_c) CEL=c=1Cyclog(y^c)=log(y^c)

其中 y ^ c \hat{y}_c y^c 表示Softmax激活函数输出结果中第 c c c 类的对应的值, C C C 是类别的总数。

Focal Loss基础概念

关键点理解

要真正理解Focal Loss,有三个关键点需要明确:

  1. 二分类(Sigmoid)和多分类(Softmax)的交叉熵损失表达形式的区别
  2. 理解难分类样本与易分类样本
  3. Focal Loss中的超参数 α \alpha α γ \gamma γ 的作用

什么是难分类样本和易分类样本?

  • 易分类样本:模型预测正确的概率较高,即 y ^ t \hat{y}_t y^t 较大(通常 y ^ t > 0.5 \hat{y}_t > 0.5 y^t>0.5)。
  • 难分类样本:模型预测正确的概率较低,即 y ^ t \hat{y}_t y^t 较小(通常 y ^ t < 0.5 \hat{y}_t < 0.5 y^t<0.5)。

其中 y ^ t \hat{y}_t y^t 定义为:
y ^ t = { y ^ if  y = 1 1 − y ^ otherwise \hat{y}_t = \begin{cases} \hat{y} & \text{if } y = 1 \\ 1 - \hat{y} & \text{otherwise} \end{cases} y^t={y^1y^if y=1otherwise

超参数 γ \gamma γ 的作用

超参数 γ \gamma γ 控制了难分类样本和易分类样本在损失函数中的比重。Focal Loss相对于原始的交叉熵损失增加了 ( 1 − y ^ t ) γ (1 - \hat{y}_t)^\gamma (1y^t)γ 这一项,对原始交叉熵损失进行了衰减。当 γ \gamma γ 增大时,对易分类样本的损失衰减更加明显,从而使模型更加关注难分类样本。

超参数 α \alpha α 的作用

超参数 α \alpha α 用于调整正负样本之间的权重。在二分类中, α \alpha α 的值反映了样本数量较少的类的权重。通常情况下,正样本数量较少(在本文中正样本代表数量少的样本),因此 α \alpha α 值反映了正样本的权重。随着 γ \gamma γ 的增加, α \alpha α 应该稍微降低。这是因为:

  • α \alpha α 对应高 γ \gamma γ。负样本通常容易被正确分类,其权重已经被 ( 1 − y ^ t ) γ (1 - \hat{y}_t)^\gamma (1y^t)γ 显著降低,因此无需给正样本再增加额外过大的权重 α \alpha α
  • 在Focal Loss中, γ \gamma γ 占主要地位,它确保了模型更加关注那些难以正确分类的样本。
  • 当处理负样本时, α \alpha α 的值通常为 1 − α 1 - \alpha 1α,其中 α \alpha α 为正样本的权重。

超参数 α \alpha α 的详细解释

在Focal Loss中, α \alpha α 的作用是调整正负样本之间的权重。理论上,数量越少的类应该具有更大的权重。然而,在原论文作者的实验中,当 α = 0.25 \alpha = 0.25 α=0.25 γ = 2 \gamma = 2 γ=2 时,模型表现最好。这引发了一个问题:为什么正样本的权重( α = 0.25 \alpha = 0.25 α=0.25)反而比负样本的权重( 1 − α = 0.75 1 - \alpha = 0.75 1α=0.75)要低,尤其是当负样本的数量远远多于正样本时?

这是因为Focal Loss的设计初衷是为了减少易分类样本的贡献,让模型更加关注难分类样本。随着 γ \gamma γ 的增加,难分类样本的权重实际上已经被显著提高了。此外,由于负样本通常更容易被正确分类,其权重已经被 ( 1 − y ^ t ) γ (1 - \hat{y}_t)^\gamma (1y^t)γ 大幅降低,因此不需要再额外增加正样本的权重。这意味着,在Focal Loss中, γ \gamma γ 的作用更为关键,而 α \alpha α 的作用则相对次要。

实例计算

假设我们有一个正样本,模型预测的概率为0.8,取 γ = 2 \gamma = 2 γ=2

  1. 计算 y ^ t \hat{y}_t y^t:
    y ^ t = y ^ = 0.8 \hat{y}_t = \hat{y} = 0.8 y^t=y^=0.8

  2. 计算Focal Loss:
    FL ( y ^ t ) = − α t ⋅ ( 1 − 0.8 ) 2 ⋅ log ⁡ ( 0.8 ) \text{FL}(\hat{y}_t) = -\alpha_t \cdot (1 - 0.8)^2 \cdot \log(0.8) FL(y^t)=αt(10.8)2log(0.8)

若取 α = 0.25 \alpha = 0.25 α=0.25,则 α t = 0.25 \alpha_t = 0.25 αt=0.25,因此:
FL ( y ^ t ) = − 0.25 ⋅ ( 0.2 ) 2 ⋅ log ⁡ ( 0.8 ) ≈ − 0.25 ⋅ 0.04 ⋅ ( − 0.22314 ) ≈ 0.00223 \text{FL}(\hat{y}_t) = -0.25 \cdot (0.2)^2 \cdot \log(0.8) \approx -0.25 \cdot 0.04 \cdot (-0.22314) \approx 0.00223 FL(y^t)=0.25(0.2)2log(0.8)0.250.04(0.22314)0.00223

负样本实例

假设我们有一个负样本,模型预测的概率为0.2,取 γ = 2 \gamma = 2 γ=2

  1. 计算 y ^ t \hat{y}_t y^t:
    y ^ t = 1 − y ^ = 1 − 0.2 = 0.8 \hat{y}_t = 1 - \hat{y} = 1 - 0.2 = 0.8 y^t=1y^=10.2=0.8

  2. 计算Focal Loss:
    FL ( y ^ t ) = − α t ⋅ ( 1 − 0.8 ) 2 ⋅ log ⁡ ( 0.8 ) \text{FL}(\hat{y}_t) = -\alpha_t \cdot (1 - 0.8)^2 \cdot \log(0.8) FL(y^t)=αt(10.8)2log(0.8)

若取 α = 0.25 \alpha = 0.25 α=0.25,则 α t = 1 − 0.25 = 0.75 \alpha_t = 1 - 0.25 = 0.75 αt=10.25=0.75,因此:
FL ( y ^ t ) = − 0.75 ⋅ ( 0.2 ) 2 ⋅ log ⁡ ( 0.8 ) ≈ − 0.75 ⋅ 0.04 ⋅ ( − 0.22314 ) ≈ 0.00669 \text{FL}(\hat{y}_t) = -0.75 \cdot (0.2)^2 \cdot \log(0.8) \approx -0.75 \cdot 0.04 \cdot (-0.22314) \approx 0.00669 FL(y^t)=0.75(0.2)2log(0.8)0.750.04(0.22314)0.00669

多分类实例

假设我们有三个类别(猫、狗、兔子),模型预测的概率分别为 [ 0.2 , 0.5 , 0.3 ] [0.2, 0.5, 0.3] [0.2,0.5,0.3],实际标签是狗(one-hot编码为[0, 1, 0]),取 γ = 2 \gamma = 2 γ=2

  1. 计算 y ^ c \hat{y}_c y^c:
    y ^ c = y ^ 2 = 0.5 \hat{y}_c = \hat{y}_2 = 0.5 y^c=y^2=0.5

  2. 计算Focal Loss:
    FL ( y ^ 2 ) = − α 2 ⋅ ( 1 − 0.5 ) 2 ⋅ log ⁡ ( 0.5 ) \text{FL}(\hat{y}_2) = -\alpha_2 \cdot (1 - 0.5)^2 \cdot \log(0.5) FL(y^2)=α2(10.5)2log(0.5)

若取 α 2 = 0.25 \alpha_2 = 0.25 α2=0.25,则:
FL ( y ^ 2 ) = − 0.25 ⋅ ( 0.5 ) 2 ⋅ log ⁡ ( 0.5 ) ≈ − 0.25 ⋅ 0.25 ⋅ ( − 0.69315 ) ≈ 0.04332 \text{FL}(\hat{y}_2) = -0.25 \cdot (0.5)^2 \cdot \log(0.5) \approx -0.25 \cdot 0.25 \cdot (-0.69315) \approx 0.04332 FL(y^2)=0.25(0.5)2log(0.5)0.250.25(0.69315)0.04332

PyTorch实现

二分类Focal Loss

import torch

class FocalLossBinary(torch.nn.Module):
    """
    二分类Focal Loss
    """
    def __init__(self, alpha=0.25, gamma=2):
        super(FocalLossBinary, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
    
    def forward(self, preds, labels):
        """
        preds: sigmoid的输出结果
        labels: 标签
        """
        eps = 1e-7
        loss_1 = -1 * self.alpha * torch.pow((1 - preds), self.gamma) * torch.log(preds + eps) * labels
        loss_0 = -1 * (1 - self.alpha) * torch.pow(preds, self.gamma) * torch.log(1 - preds + eps) * (1 - labels)
        loss = loss_0 + loss_1
        return torch.mean(loss)

多分类Focal Loss

import torch

class FocalLossMultiClass(torch.nn.Module):
    def __init__(self, weight=None, gamma=2):
        super(FocalLossMultiClass, self).__init__()
        self.gamma = gamma
        self.weight = weight
    
    def forward(self, preds, labels):
        """
        preds: softmax输出结果
        labels: 真实值
        """
        eps = 1e-7
        y_pred = preds.view((preds.size()[0], preds.size()[1], -1))  # B*C*H*W->B*C*(H*W)
        
        target = labels.view(y_pred.size())  # B*C*H*W->B*C*(H*W)
        
        ce = -1 * torch.log(y_pred + eps) * target
        floss = torch.pow((1 - y_pred), self.gamma) * ce
        if self.weight is not None:
            floss = torch.mul(floss, self.weight)
        floss = torch.sum(floss, dim=1)
        return torch.mean(floss)

结论

Focal Loss通过引入两个超参数 α \alpha α γ \gamma γ,有效地解决了类别不平衡和难易样本不平衡的问题。通过调整这些超参数,可以使模型更加关注那些难以正确分类的样本,从而提高整体性能。在实际应用中,可以通过实验来确定最佳的 α \alpha α γ \gamma γ 值。

参考文献

Focal Loss的理解以及在多分类任务上的使用(Pytorch) -
GHZhao_GIS_RS - CSDN

  • 19
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值