Focal Loss详解及其pytorch实现

Focal Loss详解及其PyTorch实现



引言

Focal Loss 是由何恺明等人在 2017 年的论文《Focal Loss for Dense Object Detection》中提出的。它是一种专门为解决目标检测中类别不平衡和难易样本不平衡问题而设计的损失函数。本文将详细介绍 Focal Loss 的基本概念、二分类和多分类的交叉熵损失函数,以及如何设置 Focal Loss 中的关键参数,并提供 PyTorch 的实现代码。

二分类与多分类的交叉熵损失函数

二分类交叉熵损失

在二分类任务中,一般使用 Sigmoid 作为最后的激活函数,输出代表样本为正的概率值 y ^ \hat{y} y^。二分类非正即负,所以样本为负的概率值为 1 − y ^ 1 - \hat{y} 1y^。二分类交叉熵损失的计算公式为:

CEL = − y ⋅ log ⁡ ( y ^ ) − ( 1 − y ) ⋅ log ⁡ ( 1 − y ^ ) \text{CEL} = -y \cdot \log(\hat{y}) - (1 - y) \cdot \log(1 - \hat{y}) CEL=ylog(y^)(1y)log(1y^)

其中:

  • y y y 是实际标签,正样本为 1,负样本为 0。
  • y ^ \hat{y} y^ 是 Sigmoid 激活函数的输出值,表示模型预测为正样本的概率。

多分类交叉熵损失

在多分类情况下,一般使用 Softmax 作为最后的激活函数,输出有多个值,对应每个类别的概率值,且这些概率之和为 1。多分类交叉熵损失的计算公式为:

CEL = − ∑ i = 1 n y i ⋅ ln ⁡ ( y ^ i ) \text{CEL} = -\sum_{i=1}^{n} y_i \cdot \ln(\hat{y}_i) CEL=i=1nyiln(y^i)

其中:

  • n n n 是类别的总数。
  • y i y_i yi 是实际标签的 one-hot 编码,若样本属于第 i i i 类,则 y i = 1 y_i = 1 yi=1,否则 y i = 0 y_i = 0 yi=0
  • y ^ i \hat{y}_i y^i 表示 Softmax 激活函数输出结果中第 i i i 类的概率值。

对于批量大小为 N N N 的样本,平均交叉熵损失为:

CEL = − 1 N ∑ j = 1 N ∑ i = 1 n y j i ⋅ ln ⁡ ( y ^ j i ) \text{CEL} = -\frac{1}{N} \sum_{j=1}^{N} \sum_{i=1}^{n} y_{ji} \cdot \ln(\hat{y}_{ji}) CEL=N1j=1Ni=1nyjiln(y^ji)

Focal Loss 基础概念

Focal Loss 公式

Focal Loss 在交叉熵损失的基础上引入了调制因子 ( 1 − p t ) γ (1 - p_t)^\gamma (1pt)γ,其中 p t p_t pt 是模型对真实类别的预测概率。Focal Loss 的公式为:

二分类 Focal Loss 公式:

FL ( p t ) = − α t ( 1 − p t ) γ log ⁡ ( p t ) \text{FL}(p_t) = -\alpha_t (1 - p_t)^\gamma \log(p_t) FL(pt)=αt(1pt)γlog(pt)

其中:

  • p t p_t pt 定义为:

    p t = { y ^ , 如果  y = 1 1 − y ^ , 如果  y = 0 p_t = \begin{cases} \hat{y}, & \text{如果 } y = 1 \\ 1 - \hat{y}, & \text{如果 } y = 0 \end{cases} pt={y^,1y^,如果 y=1如果 y=0

    • y y y 是实际标签, y = 1 y = 1 y=1 表示正样本, y = 0 y = 0 y=0 表示负样本。
    • y ^ \hat{y} y^ 是模型预测为正样本的概率。
  • α t \alpha_t αt 是样本平衡因子,定义为:

    α t = { α , 如果  y = 1 1 − α , 如果  y = 0 \alpha_t = \begin{cases} \alpha, & \text{如果 } y = 1 \\ 1 - \alpha, & \text{如果 } y = 0 \end{cases} αt={α,1α,如果 y=1如果 y=0

    • α \alpha α 是控制正负样本权重的超参数,取值范围在 [ 0 , 1 ] [0, 1] [0,1]

多分类 Focal Loss 公式:

FL ( p t ) = − α t ( 1 − p t ) γ log ⁡ ( p t ) \text{FL}(p_t) = -\alpha_t (1 - p_t)^\gamma \log(p_t) FL(pt)=αt(1pt)γlog(pt)

其中:

  • p t = y ^ c p_t = \hat{y}_c pt=y^c,表示模型对真实类别 c c c 的预测概率。
  • y ^ c \hat{y}_c y^c 是 Softmax 激活函数输出结果中真实类别的概率值。
  • α t \alpha_t αt 是对应类别的权重,可以是一个标量或一个类别权重向量。

关键点理解

要真正理解 Focal Loss,有三个关键点需要明确:

  1. 二分类(Sigmoid)和多分类(Softmax)的交叉熵损失表达形式的区别
  2. 理解难分类样本与易分类样本
  3. Focal Loss 中的超参数 α \alpha α γ \gamma γ 的作用

什么是难分类样本和易分类样本?

  • 易分类样本:模型预测正确的概率较高,即 p t p_t pt 较大(通常 p t > 0.5 p_t > 0.5 pt>0.5)。
  • 难分类样本:模型预测正确的概率较低,即 p t p_t pt 较小(通常 p t < 0.5 p_t < 0.5 pt<0.5)。

超参数 γ \gamma γ 的作用

超参数 γ \gamma γ 控制了难分类样本和易分类样本在损失函数中的比重。Focal Loss 相对于原始的交叉熵损失增加了 ( 1 − p t ) γ (1 - p_t)^\gamma (1pt)γ 这一项,对原始交叉熵损失进行了调节。当 γ \gamma γ 增大时,对易分类样本的损失衰减更加明显,从而使模型更加关注难分类样本。

超参数 α \alpha α 的作用

超参数 α \alpha α 用于调整类别之间的权重,特别是在类别不平衡的情况下。通常,数量较少的类别应当给予更高的权重。 α \alpha α 的取值范围在 [ 0 , 1 ] [0, 1] [0,1]

  • 在二分类中, α \alpha α 通常用于平衡正负样本,正样本的权重为 α \alpha α,负样本的权重为 1 − α 1 - \alpha 1α
  • 在多分类中, α \alpha α 可以是一个类别权重向量,用于平衡各个类别之间的影响。

超参数 α \alpha α 的详细解释

在 Focal Loss 中, α \alpha α 的作用是调整不同类别样本的影响力。理论上,数量越少的类别应该具有更大的权重。然而,在原论文作者的实验中,当 α = 0.25 \alpha = 0.25 α=0.25 γ = 2 \gamma = 2 γ=2 时,模型表现最好。这引发了一个问题:为什么正样本的权重( α = 0.25 \alpha = 0.25 α=0.25)反而比负样本的权重( 1 − α = 0.75 1 - \alpha = 0.75 1α=0.75)要低,尤其是当负样本的数量远远多于正样本时?

这是因为 Focal Loss 的设计初衷是为了减少易分类样本的贡献,让模型更加关注难分类样本。随着 γ \gamma γ 的增加,难分类样本的权重实际上已经被显著提高了。此外,由于负样本通常更容易被正确分类,其损失已经被 ( 1 − p t ) γ (1 - p_t)^\gamma (1pt)γ 大幅降低,因此不需要再额外增加正样本的权重。这意味着,在 Focal Loss 中, γ \gamma γ 的作用更为关键,而 α \alpha α 的作用则相对次要。

实例计算

正样本实例

假设我们有一个正样本,模型预测的概率为 y ^ = 0.8 \hat{y} = 0.8 y^=0.8,取 γ = 2 \gamma = 2 γ=2

  1. 计算 p t p_t pt

    p t = y ^ = 0.8 p_t = \hat{y} = 0.8 pt=y^=0.8

  2. 计算 Focal Loss

    FL ( p t ) = − α t ( 1 − p t ) γ log ⁡ ( p t ) \text{FL}(p_t) = -\alpha_t (1 - p_t)^\gamma \log(p_t) FL(pt)=αt(1pt)γlog(pt)

    若取 α = 0.25 \alpha = 0.25 α=0.25,则 α t = α = 0.25 \alpha_t = \alpha = 0.25 αt=α=0.25,因此:

    FL ( p t ) = − 0.25 × ( 1 − 0.8 ) 2 × ln ⁡ ( 0.8 ) ≈ − 0.25 × 0.04 × ( − 0.22314 ) ≈ 0.00223 \text{FL}(p_t) = -0.25 \times (1 - 0.8)^2 \times \ln(0.8) \approx -0.25 \times 0.04 \times (-0.22314) \approx 0.00223 FL(pt)=0.25×(10.8)2×ln(0.8)0.25×0.04×(0.22314)0.00223

负样本实例

假设我们有一个负样本,模型预测的概率为 y ^ = 0.2 \hat{y} = 0.2 y^=0.2,取 γ = 2 \gamma = 2 γ=2

  1. 计算 p t p_t pt

    p t = 1 − y ^ = 1 − 0.2 = 0.8 p_t = 1 - \hat{y} = 1 - 0.2 = 0.8 pt=1y^=10.2=0.8

  2. 计算 Focal Loss

    FL ( p t ) = − α t ( 1 − p t ) γ log ⁡ ( p t ) \text{FL}(p_t) = -\alpha_t (1 - p_t)^\gamma \log(p_t) FL(pt)=αt(1pt)γlog(pt)

    若取 α = 0.25 \alpha = 0.25 α=0.25,则 α t = 1 − α = 0.75 \alpha_t = 1 - \alpha = 0.75 αt=1α=0.75,因此:

    FL ( p t ) = − 0.75 × ( 1 − 0.8 ) 2 × ln ⁡ ( 0.8 ) ≈ − 0.75 × 0.04 × ( − 0.22314 ) ≈ 0.00669 \text{FL}(p_t) = -0.75 \times (1 - 0.8)^2 \times \ln(0.8) \approx -0.75 \times 0.04 \times (-0.22314) \approx 0.00669 FL(pt)=0.75×(10.8)2×ln(0.8)0.75×0.04×(0.22314)0.00669

多分类实例

假设我们有三个类别(猫、狗、兔子),模型预测的概率分别为 y ^ = [ 0.2 , 0.5 , 0.3 ] \hat{y} = [0.2, 0.5, 0.3] y^=[0.2,0.5,0.3],实际标签是狗(one-hot 编码为 [ 0 , 1 , 0 ] [0, 1, 0] [0,1,0]),取 γ = 2 \gamma = 2 γ=2

  1. 计算 p t p_t pt

    p t = y ^ 狗 = 0.5 p_t = \hat{y}_{\text{狗}} = 0.5 pt=y^=0.5

  2. 计算 Focal Loss

    FL ( p t ) = − α t ( 1 − p t ) γ log ⁡ ( p t ) \text{FL}(p_t) = -\alpha_t (1 - p_t)^\gamma \log(p_t) FL(pt)=αt(1pt)γlog(pt)

    若取 α t = 0.25 \alpha_t = 0.25 αt=0.25,则:

    FL ( p t ) = − 0.25 × ( 1 − 0.5 ) 2 × ln ⁡ ( 0.5 ) ≈ − 0.25 × 0.25 × ( − 0.69315 ) ≈ 0.04332 \text{FL}(p_t) = -0.25 \times (1 - 0.5)^2 \times \ln(0.5) \approx -0.25 \times 0.25 \times (-0.69315) \approx 0.04332 FL(pt)=0.25×(10.5)2×ln(0.5)0.25×0.25×(0.69315)0.04332

PyTorch 实现

二分类 Focal Loss

import torch

class FocalLossBinary(torch.nn.Module):
    """
    二分类 Focal Loss
    """
    def __init__(self, alpha=0.25, gamma=2):
        super(FocalLossBinary, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
    
    def forward(self, preds, labels):
        """
        preds: Sigmoid 的输出结果,取值范围在 [0, 1]
        labels: 标签,取值为 0 或 1
        """
        eps = 1e-7
        preds = preds.clamp(eps, 1.0 - eps)  # 防止数值稳定性问题
        p_t = preds * labels + (1 - preds) * (1 - labels)
        alpha_t = self.alpha * labels + (1 - self.alpha) * (1 - labels)
        loss = -alpha_t * torch.pow(1 - p_t, self.gamma) * torch.log(p_t)
        return loss.mean()

多分类 Focal Loss

import torch
import torch.nn.functional as F

class FocalLossMultiClass(torch.nn.Module):
    def __init__(self, alpha=None, gamma=2, reduction='mean'):
        super(FocalLossMultiClass, self).__init__()
        self.gamma = gamma
        self.alpha = alpha
        self.reduction = reduction
    
    def forward(self, preds, labels):
        """
        preds: 模型的输出,未经过 Softmax,形状为 [N, C]
        labels: 标签,形状为 [N],取值为类别的索引
        """
        eps = 1e-7
        # 计算 log-softmax
        preds_logsoft = F.log_softmax(preds, dim=1)
        # 取 softmax 概率值
        preds_softmax = torch.exp(preds_logsoft)
        # 选择真实类别对应的概率
        preds_softmax = preds_softmax.gather(1, labels.unsqueeze(1)).squeeze(1)
        preds_logsoft = preds_logsoft.gather(1, labels.unsqueeze(1)).squeeze(1)
        
        # 处理 alpha 参数
        if self.alpha is not None:
            if isinstance(self.alpha, (list, torch.Tensor)):
                # 若 alpha 为列表或张量,则为类别权重
                alpha = self.alpha[labels]
            else:
                alpha = self.alpha
        else:
            alpha = 1.0

        # 计算 Focal Loss
        loss = -alpha * torch.pow(1 - preds_softmax, self.gamma) * preds_logsoft

        if self.reduction == 'mean':
            return loss.mean()
        else:
            return loss.sum()

结论

Focal Loss 通过引入两个超参数 α \alpha α γ \gamma γ,有效地解决了类别不平衡和难易样本不平衡的问题。通过调整这些超参数,可以使模型更加关注那些难以正确分类的样本,从而提高整体性能。在实际应用中,可以通过实验来确定最佳的 α \alpha α γ \gamma γ 值。

参考文献

Focal Loss 的理解以及在多分类任务上的使用(PyTorch) - GHZhao_GIS_RS - CSDN

### Focal Loss 的定义与实现 Focal Loss 是一种用于解决类别不平衡问题的目标检测损失函数。它通过降低容易分类样本的权重,使得模型更加关注难以分类的样本[^5]。 #### 定义 Focal Loss 的核心思想在于调整交叉熵损失函数中的权重分布。对于标准的二元交叉熵损失 \(L_{CE}\),其形式如下: \[ L_{CE} = -(y \log(p) + (1-y)\log(1-p)) \] 其中 \(p\) 表示预测概率,\(y\) 为真实标签(取值为 0 或 1)。然而,在类别高度不均衡的情况下,大量简单负样本会主导训练过程。为了缓解这一问题,Focal Loss 引入了一个调节因子 \((1-p)^{\gamma}\),从而动态减少易分样本的影响。最终的形式可以表示为: \[ FL(p_t) = -\alpha_t (1-p_t)^{\gamma} \log(p_t) \] 这里: - \(p_t\) 是指目标类别的预测概率; - \(\alpha\) 是平衡正负样本比例的超参数; - \(\gamma\) 控制难易样本之间的权衡程度。 当某个样本被正确分类且置信度较高时 (\(p_t \to 1\)),项 \((1-p_t)^{\gamma}\) 将趋于零,从而使该样本对总损失贡献变小;反之亦然。 #### 实现代码 以下是基于 PyTorchFocal Loss 实现方式之一: ```python import torch import torch.nn.functional as F class FocalLoss(torch.nn.Module): def __init__(self, alpha=1, gamma=2, reduction='mean'): super(FocalLoss, self).__init__() self.alpha = alpha self.gamma = gamma self.reduction = reduction def forward(self, inputs, targets): BCE_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduce=False) pt = torch.exp(-BCE_loss) focal_loss = self.alpha * (1-pt)**self.gamma * BCE_loss if self.reduction == 'mean': return focal_loss.mean() elif self.reduction == 'sum': return focal_loss.sum() else: return focal_loss ``` 此版本适用于二分类任务,并允许自定义 `alpha` 和 `gamma` 参数来适应不同场景需求。 #### 应用领域 Focal Loss 广泛应用于图像识别、目标检测等领域,特别是在处理严重偏斜的数据集时表现出色。例如 RetinaNet 架构就采用了这种机制以提高性能并简化单阶段检测器的设计流程[^6]。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值