Loss【1】:Focal Loss

系列文章目录



前言

类别不平衡是一个在目标检测领域被广泛讨论的问题,因为目标数量的多少在数据集中能很直观的体现。同时,在分割中这也是一个值得关注的问题,毕竟分割的本质是对像素进行分类。而处理类别不平衡一个非差常用的方法就是通过 Focal Loss 来引导模型更关注困难的类。


1. 什么是 Focal Loss

Focal Loss 是在标准交叉熵损失基础上修改得到的。相比 CrossEntropy Loss 它增加了容易和难分样本的权重,对于难分的样本增加权重,增加 loss 的贡献度;减少易分类样本的权重,使得模型在训练时更专注于难分类的样本。

Focal Loss 从另外的视角来解决样本不平衡问题,那就是根据置信度动态调整 CE Loss,当预测正确的置信度增加时,loss 的权重系数会逐渐衰减至0,这样模型训练的 loss 更关注难例,而大量容易的例子其 loss 贡献很低。

比如假如一张图片上有 10 个正样本,每个正样本的损失值是 3,那么这些正样本的总损失是 10x3=30。而假如该图片上有 10000 个简单易分负样本,尽管每个负样本的损失值很小,假设是 0.1,那么这些简单易分负样本的总损失是 10000x0.1=1000,那么损失值要远远高于正样本的损失值。所以如果在训练的过程中使用全部的正负样本,那么它的训练效果会很差。


2. 逐过程解析 Focal Loss

  1. 公式一览:
    在这里插入图片描述
    • α \alpha α 侧重的是正负样本之间的不平衡,一般设置为 0.25
    • γ \gamma γ 难易样本上的权重调节,一般设置为 2
    • 简单的加权 CE Loss 可能只能实现正负样本之间不平衡的调节,所以对于大多数不平衡任务来说 Focal Loss 应该还是能起到更好的效果
  2. 首先看一下二分类交叉熵损失函数
    在这里插入图片描述
    在这里插入图片描述
  3. 二分类交叉熵损失函数: y y y 是样本的标签值,而 p p p 是模型预测某一个样本为正样本的概率,对于真实标签为正样本的样本,它的概率 p p p 越大说明模型预测的越准确,对于真实标签为负样本的样本,它的概率 p p p 越小说明模型预测的越准确
  4. 如果我们定义 p t p_t pt 为如下的形式
    在这里插入图片描述
  5. 公式 (1) 可以修改为如下形式 (2)
    在这里插入图片描述
  6. 现在我们定义一个参数 α \alpha α 1 − α 1 - \alpha 1α 来平衡正负样本的权重,定义 α t \alpha_t αt 如下,需要注意的是, α \alpha α 是个超参数用来平衡正负样本的权重,并不是实际的正负样本的比例,
    在这里插入图片描述
  7. 公式 (2) 可以修改为如下形式 (3)
    在这里插入图片描述
  8. 又因为样本有难易之分,所以我们必须要能区分出困难样本和简单样本,所以我们设置一个系数 ( 1 − p t ) γ ( 1-p_t )^{\gamma} (1pt)γ
  9. 它可以降低简单样本的损失贡献,而使得训练时更重视一些困难样本,Focal Loss 可以定义为:
    在这里插入图片描述
  10. 看一些权重计算的例子:
    在这里插入图片描述
    • 如果预测正样本概率是 0.95(即对于一个真实标签为正样本的样本,使用模型预测它也是正样本的概率是 0.95),这显然是一个简单的样本
    • 如果预测正样本概率是 0.5 ,这显然是一个稍微困难一定的样本
    • 如果预测负样本的概率为 0.9(即对于一个真实标签为负样本的样本,使用模型预测它是正样本的概率是 0.9),这显然是一个困难的样本,则该样本的难易权重是
    • 如果预测负样本的概率为 0.1(即对于一个真实标签为负样本的样本,使用模型预测它是正样本的概率是 0.1),这显然是一个简单的样本,
  11. 为此,我们得到最终的 Focal Loss
    在这里插入图片描述

3. Focal Loss 的 PyTorch 实现

首先感谢上海 AI Lab 的杰出工作,SAM-Med2D
我这里的实现来自仓库:SAM-Med2D
如果能对大家有帮助,希望后期大家不要忘记引用这个工作:

class FocalLoss(nn.Module):
    def __init__(self, gamma=2.0, alpha=0.25):
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        self.alpha = alpha

    def forward(self, pred, mask):
        """
        pred: [B, 1, H, W]
        mask: [B, 1, H, W]
        """
        assert pred.shape == mask.shape, "pred and mask should have the same shape."
        p = torch.sigmoid(pred)
        num_pos = torch.sum(mask)
        num_neg = mask.numel() - num_pos
        w_pos = (1 - p) ** self.gamma
        w_neg = p ** self.gamma

        loss_pos = -self.alpha * mask * w_pos * torch.log(p + 1e-12)
        loss_neg = -(1 - self.alpha) * (1 - mask) * w_neg * torch.log(1 - p + 1e-12)

        loss = (torch.sum(loss_pos) + torch.sum(loss_neg)) / (num_pos + num_neg + 1e-12)

        return loss


总结

参考链接:
深入剖析Focal loss损失函数

  • 13
    点赞
  • 22
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

zzzyzh

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值