CE Loss,BCE Loss以及Focal Loss的原理理解

部署运行你感兴趣的模型镜像

一、交叉熵损失函数(CE Loss,BCE Loss)

最开始理解交叉熵损失函数被自己搞的晕头转向的,最后发现是对随机变量的理解有偏差,不知道有没有读者和我有着一样的困惑,所以在本文开始之前,先介绍一下随机变量是啥。

什么是概率分布?
概率分布,是指用于表述随机变量取值的概率规律。随机变量的概率表示了一次试验中某一个结果发生的可能性大小 ,想象画在图上就是横坐标(自变量)是随机变量。根据随机变量所属类型的不同,概率分布取不同的表现形式。举个最简单的例子:抛一枚硬币,随机变量为抛硬币的结果,产生的结果的概率分布为:p(正面)=0.5,p(背面)=0.5

随机变量是什么?
随机变量是将随机试验的结果数量化,具有随机性的,注意是结果!!!在概率论中,概率质量函数(probability mass function,简写为pmf)是离散随机变量在各特定取值上的概率。一个概率质量函数的图像。函数的所有值必须非负,且总和为1。

如在抛50次硬币这个事件中,随机变量是指抛硬币获得正面的次数。不要把随机变量理解为试验的次数的取值!!!再拿二分类任务举个例子,二分类的随机变量就是看做0和1两个类别。二分类猫狗任务就相当于二项分布中的伯努利分布(试验次数为1时就叫伯努利分布,就相当于只丢一次硬币),因为去识别一张图片,最后试验的结果只能要么是猫要么是狗,这任务中的随机变量不是每一个训练样本(训练集中的每一张图片),而是分类的结果即猫or狗!在训练过程中,如果用交叉熵损失函数,假如p(x)是目标真实的分布,而q(x)是预测得来的分布。网络对每一个训练样本来讲,这张图片经过网络输出后得到的q(x)尽可能和这张图像的p(x)分布相等,x为类别的随机变量,x1为猫,x2为狗。如p(x1)=1,就是表示这张图片得到的x1这个类别的结果概率是1,所以由标签可知它的真实分布即p就是p(猫,狗)~(1,0),从训练来讲就是让这张训练样本图片经过网络输出后,得到的q(x)去无限接近上面p(猫,狗)-(1,0)这个分布。 拟合分布就是让预测分布的参数不断接近分布的参数!如p就是伯努利分布中的参数。所谓的交叉熵的交叉就是指这两个分布之间的交叉,让两个分布越接近则交叉熵损失越小。

要充分理解交叉熵损失函数,首先要理解相对熵,又称互熵。设p(x)和q(x)是两个概率分布,相对熵用来表示两个概率分布的差异,当两个随机分布相同时,它们的相对熵为零,当两个随机分布的差别增大时,它们的相对熵也会增大。

而相对熵=交叉熵-信息熵!!!
由于在机器学习和深度学习中,样本和标签已知(即p已知,样本就是xi),那么信息熵H(p)相当于常量,此时,只需拟合交叉熵,使交叉熵拟合为0即可。关键点:所以最小化交叉熵损失函数就相当于使得交叉熵公式里的p和q这两个概率分布(指交叉熵公式里的那两个乘法因子)的差异最小!式子中的n就是随机变量的取值集合,在这里就是类别数,p(xi)就是事件X=xi的概率。
在这里插入图片描述
信息熵(公式里的两个乘法因子都是指同一个分布的):
信息熵则是在结果出来之前对可能产生的信息量的期望信息量表示一条信息消除不确定性的程度,如中国目前的高铁技术世界第一,这个概率为1,这句话本身是确定的,没有消除任何不确定性。而中国的高铁技术将一直保持世界第一,这句话是个不确定事件,包含的信息量就比较大。信息量的大小和事件发生的概率成反比。信息熵越小就表示这个事件发生的概率越大,-logP就是信息量的公式(P表示事件发生的概率)。
在这里插入图片描述
交叉熵(公式是针对一个样本的,公式里的两个乘法因子分别指两个分布,n为类别数):
在这里插入图片描述

下面进入正题,也就是BCE Loss和CE Loss:

对于二分类交叉熵,下图的x1和x2是指两个类别,比如x1和x2分别代表猫和狗两类,p就是这个样本为猫的标签,这个标签可能是0也有可能是1;q就是这个样本被预测为猫的概率!
在这里插入图片描述

下图给出了多分类问题(实现为F.cross_entropy)和二分类问题(实现为F.binary_cross_entropy)的交叉熵损失公式,下图中多分类问题中的公式是针对单个样本的,公式里的i表示每一个类别。而对于二分类问题的公式即BCE loss,公式里的i表示每一个样本,所以要注意区分! 对于多分类问题即CE loss,假设真实标签的one-hot编码是:[0,0,…,1,…,0],预测的softmax概率为[0.1,0.3,…,0.4,…,0.1],那么Loss=-log(0.4)。对于二分类问题即BCE loss来说,每个样本就输出一个数字。
在这里插入图片描述

需要注意的是,BCE loss在pytorch中实现多分类损失时,也就是通过多个二分类来实现多分类时,target要转换成one-hot形式(只能有1个元素为1,其余都为0)。如下图所示,下图就是一个用BCE loss实现6分类的例子,BCE loss就把这个问题当成6个二分类实现,因为一个目标只能是属于一个类别,所以可以转换成one-hot形式。然后对于用BCE loss处理多分类问题的情况,最后其实返回的是每个类别的二分类损失求和的平均值,所以真正返回的是:4.7938/6 = 0.7990
在这里插入图片描述

二、Focal loss

Focal loss的本质

  1. 首先给出原始二分类交叉熵的公式:

在这里插入图片描述

  1. 在二分类交叉熵损失的基础上,控制了正负样本的权重来解决了正负样本的不平衡,下图就是基于二分类交叉熵损失通过α来控制正负样本比例的例子,当α=0.5时,正负样本的比重是一样的。
    在这里插入图片描述
  2. 在上面图中损失的基础上,增加控制“容易分类和难分类样本的权重”来解决难例挖掘的问题。
  3. 结合这两个方法,就是最终的二分类的Focal loss(如下图所示),最前面红框的第一项是最普通的交叉熵;第二项是控制正负样本平衡的α参数;第三项是控制难易分类样本的平衡,即对于正样本而言,预测分数越接近于1的表示这个样本越简单,那么这个样本应该对损失的影响越小:
    在这里插入图片描述
  4. 同理,多分类的Focal loss(softmax)的公式如下图所示:

这里是引用在这里插入图片描述

Focal loss的具体代码实现

# 参考了:
# 1. https://github.com/clcarwin/focal_loss_pytorch/blob/master/focalloss.py
# 2. https://github.com/c0nn3r/RetinaNet/blob/master/focal_loss.py

import torch
import torch.nn.functional as F

def focal_loss(logits, labels, gamma=2, reduction="mean"):
    r"""
    focal loss for multi classification(简洁版实现)

    `https://arxiv.org/pdf/1708.02002.pdf`

    FL(p_t)=-alpha(1-p_t)^{gamma}ln(p_t)
    """

    # 这段代码比较简洁,具体可以看作者是怎么定义的,或者看 focal_lossv1 版本的实现
    # 经测试,reduction 加不加结果都一样,但是为了保险,还是加上
    # logits是过激活函数前的值,reduction="none"就是不对loss进行求mean或者sum 保留每个样本的CE loss
    ce_loss = F.cross_entropy(logits, labels, reduction="none")
    log_pt = -ce_loss
    pt = torch.exp(log_pt)
    weights = (1 - pt) ** gamma
    fl = weights * ce_loss

    if reduction == "sum":
        fl = fl.sum()
    elif reduction == "mean":
        fl = fl.mean()
    else:
        raise ValueError(f"reduction '{reduction}' is not valid")
    return fl


def balanced_focal_loss(logits, labels, alpha=0.25, gamma=2, reduction="mean"):
    r"""
    带平衡因子的 focal loss,这里的 alpha 在多分类中应该是个向量,向量中的每个值代表类别的权重。
    但是为了简单起见,我们假设每个类一样,直接传 0.25。
    如果是长尾数据集,则应该自行构造 alpha 向量,同时改写 focal loss 函数。
    """
    return alpha * focal_loss(logits, labels, gamma, reduction)



def focal_lossv1(logits, labels, gamma=2):
    r"""
    focal loss for multi classification(第一版)

    FL(p_t)=-alpha(1-p_t)^{gamma}ln(p_t)
    """

    # pt = F.softmax(logits, dim=-1)  # 直接调用可能会溢出
    #什么是softmax的溢出:https://blog.csdn.net/qq_35054151/article/details/125891745
    # 一个不会溢出的 trick
    log_pt = F.log_softmax(logits, dim=-1)  # 这里相当于 CE loss
    #pt:tensor([[0.1617, 0.2182, 0.2946, 0.3255],
    #    [0.2455, 0.2010, 0.3314, 0.2221]])
    pt = torch.exp(log_pt)  # 通过 softmax 函数后打的分
    labels = labels.view(-1, 1)  # 多加一个维度,为使用 gather 函数做准备
    #.gather第一个参数表示根据哪个维度,第二个参数表示按照索引列表index从input中选取指定元素
    pt = pt.gather(1, labels)  # 从pt中挑选出真实值对应的 softmax 打分,也可以使用独热编码实现
    #pt,因为只有两个样本所以只有两项损失: tensor([[0.2182],
    #                                      [0.2221]])
    ce_loss = -torch.log(pt)
    weights = (1 - pt) ** gamma
    #对应元素相乘
    fl = weights * ce_loss
    #大家都是默认取均值而不是取sum
    fl = fl.mean()
    return fl


if __name__ == "__main__":
    #2个样本,4分类问题
    logits = torch.tensor([[0.3, 0.6, 0.9, 1], [0.6, 0.4, 0.9, 0.5]])
    labels = torch.tensor([1, 3])
    print(focal_loss(logits, labels))
    print(focal_loss(logits, labels, reduction="sum"))
    print(focal_lossv1(logits, labels))
    print(balanced_focal_loss(logits, labels))

Refer
交叉熵损失原理详解
随机变量的理解
GAN交叉熵
从二分类(二项分布)到多分类(多项分布)
FocalLoss 对样本不平衡的权重调节和减低损失值

再记录几个好的文章非常实用:
一文搞懂F.cross_entropy的具体实现
一文搞懂F.binary_cross_entropy以及weight参数
softmax loss详解,softmax与交叉熵的关系
二分类问题,应该选择sigmoid还是softmax?

您可能感兴趣的与本文相关的镜像

GPT-oss:20b

GPT-oss:20b

图文对话
Gpt-oss

GPT OSS 是OpenAI 推出的重量级开放模型,面向强推理、智能体任务以及多样化开发场景

<think> 首先,用户查询是关于Focal Loss原理和实现细节的详细讲解。我需要基于系统指令和提供的引用来构建响应。 系统指令回顾: - 所有行内数学表达式使用$...$格式。 - 独立公式使用$$...$$格式并单独成段。 - LaTeX语法正确。 - 使用中文回答。 - 生成相关问题。 - 回答中引用的段落末尾自然地添加引用标识。 用户提供了几个引用: - [^1]: Focal Loss讲解,提到参数示例,AP提升效果。 - [^2]: focal loss详解,解释样本不均衡,权重调整原理。 - [^3]: Focal loss原理解析,描述如何在交叉熵基础上修改。 - [^4]: Focal loss详解及其实现,包括公式推导和实现细节。 - [^5]: focal loss原理及简单代码实现,进一步解释公式。 用户的上一个问题是在第一次交互中,但在这个对话中,这是第一个用户消息,所以我需要直接从这些引用出发。 用户消息:"我想获取focal loss原理和实现细节的详细讲解 请问focal loss原理及实现细节详细讲解" 所以,我的响应应该清晰、结构化地讲解Focal Loss原理和实现细节。 结构化响应: 1. **介绍Focal Loss**:简要说明是什么、为什么需要它(解决类别不均衡问题)。 2. **原理讲解**: - 基于交叉熵损失。 - 引入调制因子来降低易分类样本的权重。 - 公式推导和解释。 3. **实现细节**: - 数学公式的实现。 - Python代码示例(类似快速排序的示例)。 4. **总结效果**:如引用中提到的AP提升。 5. **生成相关问题**:在末尾添加相关问题。 关键公式: - 从引用[4]和[5]:Focal Loss公式为 $$ FL(p_t) = - (1 - p_t)^\gamma \log(p_t) $$ - 其中 $p_t$ 是模型预测的正确类别的概率。 - $\gamma$ 是一个超参数,焦点因子(focusing parameter)。 在原理讲解中: - 解释$p_t$:对于分类,如果真实标签为y=1,则$p_t = p$(预测概率);如果y=0,则$p_t = 1-p$。 - 调制因子$(1-p_t)^\gamma$:当$p_t$接近1时,样本易分类,权重小;当$p_t$接近0时,样本难分类,权重大。 - 引入权重系数$\alpha$:引用[4]提到可以添加$\alpha_t$来控制正负样本的权重,最终公式如引用[5]:$$ FL(p_t) = - \alpha_t (1 - p_t)^\gamma \log(p_t) $$ - 其中$\alpha_t$是针对类别的权重。 实现细节: - 在Python中,可以使用PyTorch或TensorFlow实现。 - 提供一个代码片段展示如何计算Focal Loss。 参考引用: - 在回答中,自然地添加引用标识,例如在解释结束时添加[^1], [^2], 等。 - 引用是基于用户提供的,所以我会标记为[^1], [^2], etc. 生成相关问题: - 在末尾,添加类似示例的"相关问题"部分。 - 问题应该基于Focal Loss,如优化、应用等。 使用中文回答:整个响应要用中文。 数学格式: - 行内公式用$...$,例如$p_t$。 - 独立公式用$$...$$,单独成段。 响应结构: 1. 引言:什么是Focal Loss,为什么重要。 2. 原理:详细解释公式及其组成部分。 3. 实现:伪代码或Python代码。 4. 总结:简要提及效果。 5. 相关问题。 代码示例:由于用户提到实现细节,我需要提供代码。引用[5]提到简单代码实现,我可以基于此。 从引用[5]:"focal loss原理及简单代码实现",但没有给出代码。引用[4]有伪公式描述。我可以提供一个标准的PyTorch实现。 典型Focal Loss实现(分类): ```python import torch import torch.nn as nn import torch.nn.functional as F class FocalLoss(nn.Module): def __init__(self, alpha=None, gamma=2, reduction='mean'): super(FocalLoss, self).__init__() self.alpha = alpha self.gamma = gamma self.reduction = reduction def forward(self, inputs, targets): BCE_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none') pt = torch.exp(-BCE_loss) # pt = p if y=1, else 1-p F_loss = (1 - pt)**self.gamma * BCE_loss if self.alpha is not None: alpha_t = self.alpha * targets + (1 - self.alpha) * (1 - targets) F_loss = alpha_t * F_loss if self.reduction == 'mean': return torch.mean(F_loss) elif self.reduction == 'sum': return torch.sum(F_loss) else: return F_loss ``` 解释代码中的关键点。 确保响应真实可靠:基于提供的引用和标准知识。 </think> ### Focal Loss原理及实现细节详解 Focal Loss 是一种用于解决机器学习中类别不平衡问题的损失函数,最初由何恺明等人在2017年提出,应用于目标检测任务(如RetinaNet)以提升难分类样本的学习效果[^3]。它通过调整交叉熵损失,降低易分类样本的权重,使模型更专注于难分类样本。以下我将从原理和实现细节两方面详细讲解。 #### 一、Focal Loss原理 Focal Loss 建立在标准交叉熵损失的基础上,针对类别不平衡问题引入两个关键机制:**调制因子**和**权重系数**。 1. **标准交叉熵损失回顾** 在分类问题中,标准交叉熵损失定义为: $$ CE(p, y) = - \left[ y \log(p) + (1 - y) \log(1 - p) \right] $$ 其中: - $y$ 是真实标签(0 或 1), - $p$ 是模型预测的概率(范围 [0,1]), - $p_t$ 是简化表达:当 $y=1$ 时 $p_t = p$;当 $y=0$ 时 $p_t = 1 - p$。因此,$p_t$ 表示样本正确分类的置信度(越大,样本越易分类)[^4][^5]。 交叉熵损失的缺点是:所有样本权重相同,当数据中负样本远超正样本(如目标检测中背景类占多数)时,模型优化会偏向负样本;同时,易分类样本主导训练过程,难分类样本被忽略[^2]。 2. **Focal Loss 的核心公式** Focal Loss交叉熵基础上添加调制因子 $(1 - p_t)^\gamma$,降低易分类样本的权重: $$ FL(p_t) = - (1 - p_t)^\gamma \log(p_t) $$ 其中: - $\gamma$ (gamma) 是焦点参数(focusing parameter),通常 $\gamma \geq 0$(默认 $\gamma=2$)。 - 调制因子 $(1 - p_t)^\gamma$ 的作用: - 当 $p_t \to 1$(样本易分类),$(1 - p_t)^\gamma \to 0$,损失权重小,对总损失贡献小。 - 当 $p_t \to 0$(样本难分类),$(1 - p_t)^\gamma \to 1$,损失权重大,对总损失贡献大。 这使得模型更关注难分类样本(例如置信度在 0.5 左右的样本),优化决策边界[^2][^4]。 3. **引入权重系数 $\alpha$** 为解决正负样本数量不平衡问题,Focal Loss 可添加类别权重 $\alpha_t$: $$ FL(p_t) = - \alpha_t (1 - p_t)^\gamma \log(p_t) $$ - $\alpha_t$ 针对类别设置:当 $y=1$ 时 $\alpha_t = \alpha$(正样本权重),当 $y=0$ 时 $\alpha_t = 1 - \alpha$(负样本权重)。 - $\alpha$ 是超参数(范围 [0,1]),通常 $\alpha=0.25$ 用于增强少数类的权重[^4][^5]。 - 最终损失函数中,难分类样本主导优化过程,解决类别不平衡问题[^3]。 4. **效果与优势** - 实验表明,Focal Loss 能显著提升模型性能。例如,在物体检测中,AP(平均精度)从 31.1 提升到 34.0($\gamma=2$ 时)[^1]。 - 优于传统加权交叉熵,因为它同时处理了样本难度和数量不平衡[^3][^5]。 #### Focal Loss的实现细节 实现 Focal Loss 时,需注意数学推导、代码转换和参数调整。以下以 PyTorch 框架为例,分步说明实现细节。 1. **数学基础** - 定义 $p_t$:使用 sigmoid 激活函数将模型输出 logits 转换为概率 $p = \sigma(\text{logits})$。 - 计算基础交叉熵:$ \text{BCE} = - \left[ y \log(p) + (1 - y) \log(1 - p) \right] $。 - 结合调制因子:$ \text{FL} = (1 - p_t)^\gamma \times \text{BCE} $。 - 添加权重:$ \text{FL} = \alpha_t \times (1 - p_t)^\gamma \times \text{BCE} $。 其中 $p_t$ 的计算:$ p_t = p \cdot y + (1 - p) \cdot (1 - y) $,确保 $p_t$ 反映正确分类的概率[^4][^5]。 2. **Python 代码实现** 以下是 PyTorch 中的完整实现,包含注释解释关键步骤: ```python import torch import torch.nn as nn import torch.nn.functional as F class FocalLoss(nn.Module): def __init__(self, alpha=0.25, gamma=2, reduction='mean'): """ 初始化参数: - alpha: 类别权重系数,控制正负样本平衡(默认0.25,适用于正样本较少的情况) - gamma: 焦点参数,控制难易样本权重(默认2,越大,难样本权重越高) - reduction: 损失聚合方式('mean'、'sum' 或 'none') """ super(FocalLoss, self).__init__() self.alpha = alpha self.gamma = gamma self.reduction = reduction def forward(self, inputs, targets): """ 输入: - inputs: 模型输出 logits(未归一化的概率),形状 [batch_size, ...] - targets: 真实标签,与 inputs 形状相同,在 [0,1] 之间 """ # 计算概率 p 通过 sigmoid 激活 p = torch.sigmoid(inputs) # 计算 p_t:正确分类的概率 p_t = p * targets + (1 - p) * (1 - targets) # 当 targets=1 时 p_t=p;targets=0 时 p_t=1-p # 计算基础交叉熵损失 (BCE) bce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none') # 应用调制因子 (1 - p_t)^gamma modulating_factor = torch.pow(1 - p_t, self.gamma) focal_loss = modulating_factor * bce_loss # 添加类别权重 alpha_t if self.alpha is not None: alpha_t = self.alpha * targets + (1 - self.alpha) * (1 - targets) # 正样本权重 alpha,负样本权重 1-alpha focal_loss = alpha_t * focal_loss # 聚合损失(默认取均) if self.reduction == 'mean': return torch.mean(focal_loss) elif self.reduction == 'sum': return torch.sum(focal_loss) else: return focal_loss # 使用示例 if __name__ == "__main__": # 模拟数据:batch_size=2, 分类任务 inputs = torch.tensor([0.8, -0.5]) # 模型输出 logits targets = torch.tensor([1.0, 0.0]) # 真实标签 # 计算 Focal Loss focal_loss_fn = FocalLoss(alpha=0.25, gamma=2) loss = focal_loss_fn(inputs, targets) print("Focal Loss:", loss.item()) # 输出损失 ``` 3. **实现关键点** - **概率转换**:使用 `torch.sigmoid` 确保输入 logits 转换为概率 $p \in [0,1]$。 - **数稳定**:避免 $\log(0)$ 错误,PyTorch 的 `binary_cross_entropy_with_logits` 内置了稳定性处理。 - **参数调整**: - $\gamma$:增大 $\gamma$ 增强难样本权重($\gamma=0$ 时退化为标准交叉熵)。 - $\alpha$:正样本比例低时,设置 $\alpha < 0.5$。 - **多分类扩展**:Focal Loss 可扩展到多分类(使用 softmax),公式调整为 $ FL = - \alpha_c (1 - p_c)^\gamma \log(p_c) $,其中 $p_c$ 是正确类别的概率[^3][^4]。 4. **训练建议** - 初始化 $\gamma=2$, $\alpha=0.25$,根据验证集效果调整。 - 结合学习率调度(如余弦退火),避免难样本过拟合。 - 在类别不平衡任务(如目标检测、医学图像分割)中应用广泛[^1][^3]。 #### 总结 Focal Loss 通过调制因子 $(1 - p_t)^\gamma$ 和权重系数 $\alpha_t$,有效解决了类别不平衡问题,尤其提升了难分类样本的学习效率。其在 RetinaNet 等模型中实现了显著性能提升(如 AP 提高 3%)[^1]。实现时,需注意概率转换和超参数调优,以平衡训练过程。
评论 7
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值