Focal loss及其实现

Focal Loss是一种改进的交叉熵损失函数,旨在解决目标检测中正负样本不平衡的问题。该损失函数通过引入调制系数减少易分类样本的权重,使模型更关注难分类样本。适用于处理类别不平衡的数据集。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

Focal loss 出自ICCV2017 RBG和Kaiming大神的论文 Focal Loss for Dense Object Detection

 对标准的交叉熵损失做了改进,效果如下图所示。

标准的交叉熵损失函数见:loss函数之NLLLoss,CrossEntropyLoss_ltochange的博客-CSDN博客_nll函数

图中,横坐标为p_t,代表样本实际类别的预测概率, p_t越大,代表样本越容易进行分类,纵坐标为loss。

通过引入调制系数

(1-p_t)^\gamma减少loss中易分类样本的权重从而使得模型在训练时更专注于难分类的样本。

具体来说:

  1. 当一个样本被分错的时候(难分类样本),p_t很小,1-p_t 接近1,loss不被影响;
  2. p_t趋向于1(易分类样本),1-p_t 接近0,调制系数降低,对loss的贡献减小。
  3. \gamma增加的时候,调制系数也会增加。 参数\gamma平滑地调节了易分样本调低权值的比例。实验发现\gamma =2最好。
  4. 直觉上来说,当\gamma一定的时候,比如\gamma =2,易分类样本p_t=0.9的loss要比标准的交叉熵loss小100+倍,当p_t=0.968时,要小1000+倍,但是对于难分类样本p_t<0.5,loss最多小了4倍。因此,难分类样本的权重相对就提升了很多。

Focal loss最后使用的公式为:

\operatorname{FL}\left(p_{\mathrm{t}}\right)=-\alpha_{\mathrm{t}}\left(1-p_{\mathrm{t}}\right)^{\gamma} \log \left(p_{\mathrm{t}}\right)

其中,{\alpha }_t 用于控制正负样本的权重,处理样本不均衡问题(pytorch中已有实现)。

(1-p_t)^\gamma用于控制难易样本的权重,使得模型更关注难样本。

{\alpha }_t=1向量1,维度为类别数大小),\gamma =0时,即为标准交叉熵损失函数

论文实验如下图:

但是focal loss从公式上看只能用于二分类吧!对于多分类,例如自然语言处理中预测单词,可能是上万分类,即使模型训练得很好,尽管pt在所有概率中最大,但是和1还是相差比较多的,1-pt 一般情况下都是特别大的。这时可能用focal loss就不是很合适

后期补上代码,并在nlp领域尝试

pytorch实现(来自GitHub - lonePatient/TorchBlocks: A PyTorch-based toolkit for natural language processing

# coding: utf-8

import torch


class FocalLoss(nn.Module):

    def __init__(self, num_labels, activation_type='softmax', gamma=2.0, alpha=0.25, epsilon=1.e-9):

        super(FocalLoss, self).__init__()
        self.num_labels = num_labels
        self.gamma = gamma
        self.alpha = alpha
        self.epsilon = epsilon
        self.activation_type = activation_type

    def forward(self, input, target):
        """
        Args:
            logits: model's output, shape of [batch_size, num_cls]
            target: ground truth labels, shape of [batch_size]
        Returns:
            shape of [batch_size]
        """
        if self.activation_type == 'softmax':
            idx = target.view(-1, 1).long()
            one_hot_key = torch.zeros(idx.size(0), self.num_labels, dtype=torch.float32, device=idx.device)
            one_hot_key = one_hot_key.scatter_(1, idx, 1)
            logits = torch.softmax(input, dim=-1)
            loss = -self.alpha * one_hot_key * torch.pow((1 - logits), self.gamma) * (logits + self.epsilon).log()
            loss = loss.sum(1)
        elif self.activation_type == 'sigmoid':
            multi_hot_key = target
            logits = torch.sigmoid(input)
            zero_hot_key = 1 - multi_hot_key
            loss = -self.alpha * multi_hot_key * torch.pow((1 - logits), self.gamma) * (logits + self.epsilon).log()
            loss += -(1 - self.alpha) * zero_hot_key * torch.pow(logits, self.gamma) * (1 - logits + self.epsilon).log()
        return loss.mean()

参考:

Focal loss论文详解 - 知乎

Focal Loss 的Pytorch 实现以及实验 - 知乎

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

旺旺棒棒冰

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值