Focal Loss详解及其PyTorch实现
文章目录
引言
Focal Loss 是由何恺明等人在 2017 年的论文《Focal Loss for Dense Object Detection》中提出的。它是一种专门为解决目标检测中类别不平衡和难易样本不平衡问题而设计的损失函数。本文将详细介绍 Focal Loss 的基本概念、二分类和多分类的交叉熵损失函数,以及如何设置 Focal Loss 中的关键参数,并提供 PyTorch 的实现代码。
二分类与多分类的交叉熵损失函数
二分类交叉熵损失
在二分类任务中,一般使用 Sigmoid 作为最后的激活函数,输出代表样本为正的概率值 y ^ \hat{y} y^。二分类非正即负,所以样本为负的概率值为 1 − y ^ 1 - \hat{y} 1−y^。二分类交叉熵损失的计算公式为:
CEL = − y ⋅ log ( y ^ ) − ( 1 − y ) ⋅ log ( 1 − y ^ ) \text{CEL} = -y \cdot \log(\hat{y}) - (1 - y) \cdot \log(1 - \hat{y}) CEL=−y⋅log(y^)−(1−y)⋅log(1−y^)
其中:
- y y y 是实际标签,正样本为 1,负样本为 0。
- y ^ \hat{y} y^ 是 Sigmoid 激活函数的输出值,表示模型预测为正样本的概率。
多分类交叉熵损失
在多分类情况下,一般使用 Softmax 作为最后的激活函数,输出有多个值,对应每个类别的概率值,且这些概率之和为 1。多分类交叉熵损失的计算公式为:
CEL = − ∑ i = 1 n y i ⋅ ln ( y ^ i ) \text{CEL} = -\sum_{i=1}^{n} y_i \cdot \ln(\hat{y}_i) CEL=−i=1∑nyi⋅ln(y^i)
其中:
- n n n 是类别的总数。
- y i y_i yi 是实际标签的 one-hot 编码,若样本属于第 i i i 类,则 y i = 1 y_i = 1 yi=1,否则 y i = 0 y_i = 0 yi=0。
- y ^ i \hat{y}_i y^i 表示 Softmax 激活函数输出结果中第 i i i 类的概率值。
对于批量大小为 N N N 的样本,平均交叉熵损失为:
CEL = − 1 N ∑ j = 1 N ∑ i = 1 n y j i ⋅ ln ( y ^ j i ) \text{CEL} = -\frac{1}{N} \sum_{j=1}^{N} \sum_{i=1}^{n} y_{ji} \cdot \ln(\hat{y}_{ji}) CEL=−N1j=1∑Ni=1∑nyji⋅ln(y^ji)
Focal Loss 基础概念
Focal Loss 公式
Focal Loss 在交叉熵损失的基础上引入了调制因子 ( 1 − p t ) γ (1 - p_t)^\gamma (1−pt)γ,其中 p t p_t pt 是模型对真实类别的预测概率。Focal Loss 的公式为:
二分类 Focal Loss 公式:
FL ( p t ) = − α t ( 1 − p t ) γ log ( p t ) \text{FL}(p_t) = -\alpha_t (1 - p_t)^\gamma \log(p_t) FL(pt)=−αt(1−pt)γlog(pt)
其中:
-
p t p_t pt 定义为:
p t = { y ^ , 如果 y = 1 1 − y ^ , 如果 y = 0 p_t = \begin{cases} \hat{y}, & \text{如果 } y = 1 \\ 1 - \hat{y}, & \text{如果 } y = 0 \end{cases} pt={y^,1−y^,如果 y=1如果 y=0
- y y y 是实际标签, y = 1 y = 1 y=1 表示正样本, y = 0 y = 0 y=0 表示负样本。
- y ^ \hat{y} y^ 是模型预测为正样本的概率。
-
α t \alpha_t αt 是样本平衡因子,定义为:
α t = { α , 如果 y = 1 1 − α , 如果 y = 0 \alpha_t = \begin{cases} \alpha, & \text{如果 } y = 1 \\ 1 - \alpha, & \text{如果 } y = 0 \end{cases} αt={α,1−α,如果 y=1如果 y=0
- α \alpha α 是控制正负样本权重的超参数,取值范围在 [ 0 , 1 ] [0, 1] [0,1]。
多分类 Focal Loss 公式:
FL ( p t ) = − α t ( 1 − p t ) γ log ( p t ) \text{FL}(p_t) = -\alpha_t (1 - p_t)^\gamma \log(p_t) FL(pt)=−αt(1−pt)γlog(pt)
其中:
- p t = y ^ c p_t = \hat{y}_c pt=y^c,表示模型对真实类别 c c c 的预测概率。
- y ^ c \hat{y}_c y^c 是 Softmax 激活函数输出结果中真实类别的概率值。
- α t \alpha_t αt 是对应类别的权重,可以是一个标量或一个类别权重向量。
关键点理解
要真正理解 Focal Loss,有三个关键点需要明确:
- 二分类(Sigmoid)和多分类(Softmax)的交叉熵损失表达形式的区别。
- 理解难分类样本与易分类样本。
- Focal Loss 中的超参数 α \alpha α 和 γ \gamma γ 的作用。
什么是难分类样本和易分类样本?
- 易分类样本:模型预测正确的概率较高,即 p t p_t pt 较大(通常 p t > 0.5 p_t > 0.5 pt>0.5)。
- 难分类样本:模型预测正确的概率较低,即 p t p_t pt 较小(通常 p t < 0.5 p_t < 0.5 pt<0.5)。
超参数 γ \gamma γ 的作用
超参数 γ \gamma γ 控制了难分类样本和易分类样本在损失函数中的比重。Focal Loss 相对于原始的交叉熵损失增加了 ( 1 − p t ) γ (1 - p_t)^\gamma (1−pt)γ 这一项,对原始交叉熵损失进行了调节。当 γ \gamma γ 增大时,对易分类样本的损失衰减更加明显,从而使模型更加关注难分类样本。
超参数 α \alpha α 的作用
超参数 α \alpha α 用于调整类别之间的权重,特别是在类别不平衡的情况下。通常,数量较少的类别应当给予更高的权重。 α \alpha α 的取值范围在 [ 0 , 1 ] [0, 1] [0,1]。
- 在二分类中, α \alpha α 通常用于平衡正负样本,正样本的权重为 α \alpha α,负样本的权重为 1 − α 1 - \alpha 1−α。
- 在多分类中, α \alpha α 可以是一个类别权重向量,用于平衡各个类别之间的影响。
超参数 α \alpha α 的详细解释
在 Focal Loss 中, α \alpha α 的作用是调整不同类别样本的影响力。理论上,数量越少的类别应该具有更大的权重。然而,在原论文作者的实验中,当 α = 0.25 \alpha = 0.25 α=0.25 和 γ = 2 \gamma = 2 γ=2 时,模型表现最好。这引发了一个问题:为什么正样本的权重( α = 0.25 \alpha = 0.25 α=0.25)反而比负样本的权重( 1 − α = 0.75 1 - \alpha = 0.75 1−α=0.75)要低,尤其是当负样本的数量远远多于正样本时?
这是因为 Focal Loss 的设计初衷是为了减少易分类样本的贡献,让模型更加关注难分类样本。随着 γ \gamma γ 的增加,难分类样本的权重实际上已经被显著提高了。此外,由于负样本通常更容易被正确分类,其损失已经被 ( 1 − p t ) γ (1 - p_t)^\gamma (1−pt)γ 大幅降低,因此不需要再额外增加正样本的权重。这意味着,在 Focal Loss 中, γ \gamma γ 的作用更为关键,而 α \alpha α 的作用则相对次要。
实例计算
正样本实例
假设我们有一个正样本,模型预测的概率为 y ^ = 0.8 \hat{y} = 0.8 y^=0.8,取 γ = 2 \gamma = 2 γ=2。
-
计算 p t p_t pt:
p t = y ^ = 0.8 p_t = \hat{y} = 0.8 pt=y^=0.8
-
计算 Focal Loss:
FL ( p t ) = − α t ( 1 − p t ) γ log ( p t ) \text{FL}(p_t) = -\alpha_t (1 - p_t)^\gamma \log(p_t) FL(pt)=−αt(1−pt)γlog(pt)
若取 α = 0.25 \alpha = 0.25 α=0.25,则 α t = α = 0.25 \alpha_t = \alpha = 0.25 αt=α=0.25,因此:
FL ( p t ) = − 0.25 × ( 1 − 0.8 ) 2 × ln ( 0.8 ) ≈ − 0.25 × 0.04 × ( − 0.22314 ) ≈ 0.00223 \text{FL}(p_t) = -0.25 \times (1 - 0.8)^2 \times \ln(0.8) \approx -0.25 \times 0.04 \times (-0.22314) \approx 0.00223 FL(pt)=−0.25×(1−0.8)2×ln(0.8)≈−0.25×0.04×(−0.22314)≈0.00223
负样本实例
假设我们有一个负样本,模型预测的概率为 y ^ = 0.2 \hat{y} = 0.2 y^=0.2,取 γ = 2 \gamma = 2 γ=2。
-
计算 p t p_t pt:
p t = 1 − y ^ = 1 − 0.2 = 0.8 p_t = 1 - \hat{y} = 1 - 0.2 = 0.8 pt=1−y^=1−0.2=0.8
-
计算 Focal Loss:
FL ( p t ) = − α t ( 1 − p t ) γ log ( p t ) \text{FL}(p_t) = -\alpha_t (1 - p_t)^\gamma \log(p_t) FL(pt)=−αt(1−pt)γlog(pt)
若取 α = 0.25 \alpha = 0.25 α=0.25,则 α t = 1 − α = 0.75 \alpha_t = 1 - \alpha = 0.75 αt=1−α=0.75,因此:
FL ( p t ) = − 0.75 × ( 1 − 0.8 ) 2 × ln ( 0.8 ) ≈ − 0.75 × 0.04 × ( − 0.22314 ) ≈ 0.00669 \text{FL}(p_t) = -0.75 \times (1 - 0.8)^2 \times \ln(0.8) \approx -0.75 \times 0.04 \times (-0.22314) \approx 0.00669 FL(pt)=−0.75×(1−0.8)2×ln(0.8)≈−0.75×0.04×(−0.22314)≈0.00669
多分类实例
假设我们有三个类别(猫、狗、兔子),模型预测的概率分别为 y ^ = [ 0.2 , 0.5 , 0.3 ] \hat{y} = [0.2, 0.5, 0.3] y^=[0.2,0.5,0.3],实际标签是狗(one-hot 编码为 [ 0 , 1 , 0 ] [0, 1, 0] [0,1,0]),取 γ = 2 \gamma = 2 γ=2。
-
计算 p t p_t pt:
p t = y ^ 狗 = 0.5 p_t = \hat{y}_{\text{狗}} = 0.5 pt=y^狗=0.5
-
计算 Focal Loss:
FL ( p t ) = − α t ( 1 − p t ) γ log ( p t ) \text{FL}(p_t) = -\alpha_t (1 - p_t)^\gamma \log(p_t) FL(pt)=−αt(1−pt)γlog(pt)
若取 α t = 0.25 \alpha_t = 0.25 αt=0.25,则:
FL ( p t ) = − 0.25 × ( 1 − 0.5 ) 2 × ln ( 0.5 ) ≈ − 0.25 × 0.25 × ( − 0.69315 ) ≈ 0.04332 \text{FL}(p_t) = -0.25 \times (1 - 0.5)^2 \times \ln(0.5) \approx -0.25 \times 0.25 \times (-0.69315) \approx 0.04332 FL(pt)=−0.25×(1−0.5)2×ln(0.5)≈−0.25×0.25×(−0.69315)≈0.04332
PyTorch 实现
二分类 Focal Loss
import torch
class FocalLossBinary(torch.nn.Module):
"""
二分类 Focal Loss
"""
def __init__(self, alpha=0.25, gamma=2):
super(FocalLossBinary, self).__init__()
self.alpha = alpha
self.gamma = gamma
def forward(self, preds, labels):
"""
preds: Sigmoid 的输出结果,取值范围在 [0, 1]
labels: 标签,取值为 0 或 1
"""
eps = 1e-7
preds = preds.clamp(eps, 1.0 - eps) # 防止数值稳定性问题
p_t = preds * labels + (1 - preds) * (1 - labels)
alpha_t = self.alpha * labels + (1 - self.alpha) * (1 - labels)
loss = -alpha_t * torch.pow(1 - p_t, self.gamma) * torch.log(p_t)
return loss.mean()
多分类 Focal Loss
import torch
import torch.nn.functional as F
class FocalLossMultiClass(torch.nn.Module):
def __init__(self, alpha=None, gamma=2, reduction='mean'):
super(FocalLossMultiClass, self).__init__()
self.gamma = gamma
self.alpha = alpha
self.reduction = reduction
def forward(self, preds, labels):
"""
preds: 模型的输出,未经过 Softmax,形状为 [N, C]
labels: 标签,形状为 [N],取值为类别的索引
"""
eps = 1e-7
# 计算 log-softmax
preds_logsoft = F.log_softmax(preds, dim=1)
# 取 softmax 概率值
preds_softmax = torch.exp(preds_logsoft)
# 选择真实类别对应的概率
preds_softmax = preds_softmax.gather(1, labels.unsqueeze(1)).squeeze(1)
preds_logsoft = preds_logsoft.gather(1, labels.unsqueeze(1)).squeeze(1)
# 处理 alpha 参数
if self.alpha is not None:
if isinstance(self.alpha, (list, torch.Tensor)):
# 若 alpha 为列表或张量,则为类别权重
alpha = self.alpha[labels]
else:
alpha = self.alpha
else:
alpha = 1.0
# 计算 Focal Loss
loss = -alpha * torch.pow(1 - preds_softmax, self.gamma) * preds_logsoft
if self.reduction == 'mean':
return loss.mean()
else:
return loss.sum()
结论
Focal Loss 通过引入两个超参数 α \alpha α 和 γ \gamma γ,有效地解决了类别不平衡和难易样本不平衡的问题。通过调整这些超参数,可以使模型更加关注那些难以正确分类的样本,从而提高整体性能。在实际应用中,可以通过实验来确定最佳的 α \alpha α 和 γ \gamma γ 值。