FocalLoss原理通俗解释及其二分类和多分类场景下的原理与实现

1. FocalLoss的应用场景

学一个东西,首先要知道这个东西是干嘛用的。

FocalLoss主要有两个作用,这也决定了它的应用场景:

  1. FocalLoss可以调节正负样本的loss权重。这意味着,当正负样本数量及其不平衡时,可以考虑使用FocalLoss。
  2. FocalLoss可以调节难易样本的loss权重。这意味着,当训练样本的难易程度不平衡时,可以考虑使用FocalLoss

这也是“Focal Loss”的名字的含义,把目光聚焦(Focal)在那些“少的,难的”样本上。

虽然大部分博客讨论FocalLoss都是在目标检测场景下,但其实FocalLoss其他场景下都可以用。

举个NLP的应用场景:

  1. 当我们在情感分类(好评/差评)时,若99%都是好评,只有1%是差评,就可以考虑使用FocalLoss通过loss来调节数据不平衡问题。
  2. 情感分类问题有些样本很难,例如:“我家狗吃了你的菜连夜给我做了四菜一汤”。而有些样本很简单,例如“差评,太难吃了”。这种场景下,FocalLoss可以帮助调节难易样本的loss权重,从而更好的学习到难样本的特征。

2. 二分类场景下FocalLoss原理解释

本节会分别讨论FocalLoss是如何实现其两个功能的,然后再进行整合。

2.1 FocalLoss如何调节正负样本权重

二分类问题我们通常使用交叉熵计算Loss,损失函数如下:

C E ( p , y ) = { − log ⁡ ( p )  if  y = 1 − log ⁡ ( 1 − p )  if  y = 0 \mathrm{CE}(p, y)= \begin{cases}-\log (p) & \text { if } y=1 \\ -\log (1-p) & \text { if } y=0\end{cases} CE(p,y)={log(p)log(1p) if y=1 if y=0

其中CE是CrossEntropy的缩写, p p p 是预测结果,例如0.8。 y y y 是标签值。

假设我们99%的样本都是负样本,那么最终计算出的loss负样本占比极大。要进行调节,很简单,只需要乘个权重就行了。比如:

C E ( p , y ) = { − log ⁡ ( p ) ∗ 0.9  if  y = 1 − log ⁡ ( 1 − p ) ∗ 0.1  if  y = 0 \mathrm{CE}(p, y)= \begin{cases}-\log (p) * 0.9 & \text { if } y=1 \\ -\log (1-p) * 0.1 & \text { if } y=0\end{cases} CE(p,y)={log(p)0.9log(1p)0.1 if y=1 if y=0

我们让正样本和负样本的loss给个9:1的权重就行了。将其0.9写成变量 α \alpha α 为:

C E ( p , y , α ) = { − log ⁡ ( p ) ∗ α  if  y = 1 − log ⁡ ( 1 − p ) ∗ ( 1 − α )  if  y = 0 \mathrm{CE}(p, y, \alpha)= \begin{cases}-\log (p) * \alpha& \text { if } y=1 \\ -\log (1-p) * (1-\alpha) & \text { if } y=0\end{cases} CE(p,y,α)={log(p)αlog(1p)(1α) if y=1 if y=0

其中, α ∈ ( 0 , 1 ) \alpha \in (0,1) α(0,1) 为超参数。这就是FocalLoss调节正负样本权重的方式,简单吧。

2.2 FocalLoss如何调节难易样本权重

当我们在训练二分类问题时,经过sigmoid后最终的输出是0到1的概率,表示为正样本的概率是多少。

那假设标签为1的样本:

  • 若预测为为0.95,意味该样本是一个比较简单的样本。
  • 若预测值为0.65,意味着该样本稍微有点难
  • 若预测值为0.28,意味着该样本非常难。

负样本同理。即 预测值距离真值越远,则样本越难

难样本想要多学习,那就给它的loss分个较大的权重,简单样本易学习,那就给个较小的权重。那我们可以直接用它的难易程度给它分权重嘛,例如:

假设标签为1的样本:

  • 若预测为为0.95,意味该样本是一个比较简单的样本。权重为 (1-0.95) = 0.05
  • 若预测值为0.65,意味着该样本稍微有点难。权重为 (1-0.65) = 0.35
  • 若预测值为0.28,意味着该样本非常难。权重为 (1-0.28) = 0.72

按照这个思路,我们就可以得到如下损失函数:

C E ( p , y ) = { − log ⁡ ( p ) ∗ ( 1 − p )  if  y = 1 − log ⁡ ( 1 − p ) ∗ p  if  y = 0 \mathrm{CE}(p, y)= \begin{cases}-\log (p) * (1-p) & \text { if } y=1 \\ -\log (1-p) * p & \text { if } y=0\end{cases} CE(p,y)={log(p)(1p)log(1p)p if y=1 if y=0

这样你可能还不过瘾,你想让简单样本权重更低,难样本权重更高,那么也很简单,只需要加个平方就行了,这样小的会更小,大的会更大。这样我们会得到如下公式:

C E ( p , y ) = { − log ⁡ ( p ) ∗ ( 1 − p ) 2  if  y = 1 − log ⁡ ( 1 − p ) ∗ p 2  if  y = 0 \mathrm{CE}(p, y)= \begin{cases}-\log (p) * (1-p)^2 & \text { if } y=1 \\ -\log (1-p) * p^2 & \text { if } y=0\end{cases} CE(p,y)={log(p)(1p)2log(1p)p2 if y=1 if y=0

但你可能会觉得平方太小或太大,那么我们把平方写成超参数 γ \gamma γ,此时公式就变成了如下:

C E ( p , y ) = { − log ⁡ ( p ) ∗ ( 1 − p ) γ  if  y = 1 − log ⁡ ( 1 − p ) ∗ p γ  if  y = 0 \mathrm{CE}(p, y)= \begin{cases}-\log (p) * (1-p)^\gamma & \text { if } y=1 \\ -\log (1-p) * p^\gamma & \text { if } y=0\end{cases} CE(p,y)={log(p)(1p)γlog(1p)pγ if y=1 if y=0

这样我们就完成了难易样本权重的调节。最后再总结一下参数 γ \gamma γ

  • γ = 0 \gamma=0 γ=0 时,损失函数退化成了原始的CrossEntropy。
  • γ \gamma γ 越大,对权重的调节就越狠,反之则越轻。越狠的意思是:容易样本的权重会越低,难样本的权重会越高。
  • γ \gamma γ 通常取 2 2 2

2.3 整合上述过程,完成FocalLoss

整合过程很简单,把 α \alpha α γ \gamma γ 一块放到公式里就行了,FocalLoss公式如下:

F L ( p , y , α , γ ) = { − log ⁡ ( p ) ∗ α ∗ ( 1 − p ) γ  if  y = 1 − log ⁡ ( 1 − p ) ∗ ( 1 − α ) ∗ p γ  if  y = 0 \mathrm{FL}(p, y, \alpha, \gamma)= \begin{cases}-\log (p) *\alpha * (1-p)^\gamma & \text { if } y=1 \\ -\log (1-p) *(1-\alpha)* p^\gamma & \text { if } y=0\end{cases} FL(p,y,α,γ)={log(p)α(1p)γlog(1p)(1α)pγ if y=1 if y=0

这样写稍显难看,所以我们定义两个新的变量 α t \alpha_t αt p t p_t pt, 其中:

α t = { α  if  y = 1 1 − α  if  y = 0 ,        p t = { p  if  y = 1 1 − p  if  y = 0 \alpha_t= \begin{cases} \alpha & \text { if } y=1 \\ 1-\alpha & \text { if } y=0\end{cases}, ~~~~~~p_t= \begin{cases} p & \text { if } y=1 \\ 1-p & \text { if } y=0\end{cases} αt={α1α if y=1 if y=0,      pt={p1p if y=1 if y=0

那么FocalLoss就可以写成如下的最终公式:

F L ( p t ) = − α t ( 1 − p t ) γ log ⁡ ( p t ) \mathrm{FL}(p_t) = -\alpha_t(1-p_t)^\gamma \log (p_t) FL(pt)=αt(1pt)γlog(pt)

这就是FocalLoss的公式。

2.4 Pytorch 实现FocalLoss

import torch
from torch import nn


class BinaryFocalLoss(nn.Module):
    """
    参考 https://github.com/lonePatient/TorchBlocks
    """

    def __init__(self, gamma=2.0, alpha=0.25, epsilon=1.e-9):
        super(BinaryFocalLoss, self).__init__()
        self.gamma = gamma
        self.alpha = alpha
        self.epsilon = epsilon

    def forward(self, input, target):
        """
        Args:
            input: model's output, shape of [batch_size, num_cls]
            target: ground truth labels, shape of [batch_size]
        Returns:
            shape of [batch_size]
        """
        multi_hot_key = target
        logits = input
        # 如果模型没有做sigmoid的话,这里需要加上
        # logits = torch.sigmoid(logits)
        zero_hot_key = 1 - multi_hot_key
        loss = -self.alpha * multi_hot_key * torch.pow((1 - logits), self.gamma) * (logits + self.epsilon).log()
        loss += -(1 - self.alpha) * zero_hot_key * torch.pow(logits, self.gamma) * (1 - logits + self.epsilon).log()
        return loss.mean()


if __name__ == '__main__':
    m = nn.Sigmoid()
    loss = BinaryFocalLoss()
    input = torch.randn(3, requires_grad=True)
    target = torch.empty(3).random_(2)
    output = loss(m(input), target)
    print("loss:", output)
    output.backward()

3. 多分类场景下的FocalLoss

有了前面二分类的基础,多分类就影刃而解了。

3.1 FocalLoss调节多分类的类别权重

假设我们有个三分类的场景,y=(1, 2, 3),他们的样本数量分别是100个,2000个和10000个。

那么此时我们的 α t \alpha_t αt 就不能再是一个数了,而是要是一个list,例如:

α t = { 0.7  if  y = 1 0.25  if  y = 2 0.05  if  y = 3 \alpha_t= \begin{cases} 0.7 & \text { if } y=1 \\ 0.25 & \text { if } y= 2 \\ 0.05 & \text { if } y= 3 \end{cases} αt= 0.70.250.05 if y=1 if y=2 if y=3

在多分类场景下,我们的 α \alpha α 超参由一个数变成了一个数组,我们要为每一个类别指定权重。样本多的权重低一点,样本少的权重高一点。写成具体公式则为:

α t = { α 1  if  y = 1 α 2  if  y = 2 . . .  if  y = . . . α n  if  y = n \alpha_t= \begin{cases} \alpha_1 & \text { if } y=1 \\ \alpha_2 & \text { if } y= 2 \\ ... & \text { if } y= ... \\ \alpha_n & \text { if } y= n \end{cases} αt= α1α2...αn if y=1 if y=2 if y=... if y=n

其中 n n n 表示有 n n n 个类别。

在大部分博客甚至开源项目上,在多分类问题上 α \alpha α 也还是一个数。但在多分类问题上,这样理论上没办法解决调整不平衡数据的问题,相当于所有的数据都乘以了一个小数,没效果。 上面的理解是我在向ChatGPT提问时它给出的答案,我觉得这个是比较合理的。

3.2 FocalLoss调节多分类难易样本权重

同样,假设我们有个三分类的场景,y=(1, 2, 3),对于某个样本的预测结果如下:

p = [ 0.85 , 0.1 , 0.05 ] T p= [0.85, 0.1, 0.05]^T p=[0.85,0.1,0.05]T

  • 若标签为1,那么构造难易程度的调制因子时只需要考虑 0.85 0.85 0.85 即可,同时也说明这个样本是一个简单样本
  • 若标签为2,同理,调制因子只需要考虑 0.1 0.1 0.1,同时说明这个样本很难。
  • 若标签为3,同理。

为了达到上述目的,我们可以使用one-hot向量来把不关心的非标签值给抹去,即:

[ 1 0 0 ] × [ 0.85 0.1 0.05 ] = [ 0.85 0 0 ] \begin{bmatrix} 1 \\ 0 \\ 0 \end{bmatrix} \times \begin{bmatrix} 0.85 \\ 0.1 \\ 0.05 \end{bmatrix} = \begin{bmatrix} 0.85 \\ 0 \\ 0 \end{bmatrix} 100 × 0.850.10.05 = 0.8500

此时,我们把其与调制因子结合,为 h ∗ ( 1 − p t ) γ h *(1-p_t)^\gamma h(1pt)γ,为:

[ 1 0 0 ] × ( 1 − [ 0.85 0.1 0.05 ] ) γ = [ 0.1 5 γ 0 0 ] \begin{bmatrix} 1 \\ 0 \\ 0 \end{bmatrix} \times(1- \begin{bmatrix} 0.85 \\ 0.1 \\ 0.05 \end{bmatrix})^\gamma = \begin{bmatrix} 0.15^\gamma \\ 0 \\ 0 \end{bmatrix} 100 ×(1 0.850.10.05 )γ= 0.15γ00

这里,我们将 one-hot 向量用 h h h 表示。这里的 0.1 5 γ 0.15^\gamma 0.15γ 就是我们要给该样本附加的权重。

3.3 整合上述过程,完成多分类的FocalLoss

综上所述,在多分类场景下,FocalLoss的公式变成了如下:

F L ( p t ) = ∑ − α t ∗ h ∗ ( 1 − p t ) γ log ⁡ ( p t ) \mathrm{FL}(p_t) = \sum -\alpha_t*h*(1-p_t)^\gamma \log (p_t) FL(pt)=αth(1pt)γlog(pt)

这里, α t \alpha_t αt p t p_t pt 的含义与二分类不同, α t \alpha_t αt 为一个列表,里面是每个类别的权重。而 p t p_t pt 是输出的概率分布, h h h 是当前样本的one-hot向量。

举个实际的例子来看一下该公式:

假设我们有个三分类的场景,y=(1, 2, 3),其中 α t = [ 0.7 , 0.25 , 0.05 ] T \alpha_t=[0.7, 0.25, 0.05]^T αt=[0.7,0.25,0.05]T γ = 2 \gamma=2 γ=2。对于样本 y = 1 y=1 y=1,输出的概率分布为 p = [ 0.85 , 0.1 , 0.05 ] t p=[0.85,0.1,0.05]^t p=[0.85,0.1,0.05]t,,则FocalLoss为:

F L l o s s = sum ( − [ 0.7 0.25 0.05 ] × [ 1 0 0 ] × ( 1 − [ 0.85 0.1 0.05 ] ) 2 × log ⁡ ( [ 0.85 0.1 0.05 ] ) ) = sum ( − [ 0.7 0.25 0.05 ] × [ 0.1 5 2 0 0 ] × log ⁡ ( [ 0.85 0.1 0.05 ] ) ) = sum ( [ 0.0026 0 0 ] ) = 0.0026 \begin{aligned} \mathrm{FL_{loss}} = & \text{sum} (-\begin{bmatrix} 0.7 \\ 0.25 \\ 0.05 \end{bmatrix}\times \begin{bmatrix} 1 \\ 0 \\ 0 \end{bmatrix}\times (1- \begin{bmatrix} 0.85 \\ 0.1 \\ 0.05 \end{bmatrix})^2 \times \log(\begin{bmatrix} 0.85 \\ 0.1 \\ 0.05 \end{bmatrix})) \\ \\ = & \text{sum} (-\begin{bmatrix} 0.7 \\ 0.25 \\ 0.05 \end{bmatrix}\times \begin{bmatrix} 0.15^2 \\ 0 \\ 0 \end{bmatrix} \times \log(\begin{bmatrix} 0.85 \\ 0.1 \\ 0.05 \end{bmatrix})) \\ \\ = & \text{sum}(\begin{bmatrix} 0.0026 \\ 0 \\ 0 \end{bmatrix}) \\ \\ =&0.0026 \end{aligned} FLloss====sum( 0.70.250.05 × 100 ×(1 0.850.10.05 )2×log( 0.850.10.05 ))sum( 0.70.250.05 × 0.15200 ×log( 0.850.10.05 ))sum( 0.002600 )0.0026

从上面例子可以看出,因为one-hot的存在,真正对loss起作用的其实只有样本所在的那一行。

因此,我们可以将FocalLoss公式改进为如下:

F L = − α c ( 1 − p c ) γ log ⁡ ( p c ) \mathrm{FL} = -\alpha_c(1-p_c)^\gamma \log(p_c) FL=αc(1pc)γlog(pc)

其中 c c c 为当前样本的类别, α c \alpha_c αc 表示类别c对应的权重, p c p_c pc 表示输出概率分布对于类别 c 的概率值。

3.4 Pytorch 实现多分类FocalLoss

class FocalLoss(nn.Module):
    """
    参考 https://github.com/lonePatient/TorchBlocks
    """

    def __init__(self, gamma=2.0, alpha=1, epsilon=1.e-9, device=None):
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        if isinstance(alpha, list):
            self.alpha = torch.Tensor(alpha, device=device)
        else:
            self.alpha = alpha
        self.epsilon = epsilon

    def forward(self, input, target):
        """
        Args:
            input: model's output, shape of [batch_size, num_cls]
            target: ground truth labels, shape of [batch_size]
        Returns:
            shape of [batch_size]
        """
        num_labels = input.size(-1)
        idx = target.view(-1, 1).long()
        one_hot_key = torch.zeros(idx.size(0), num_labels, dtype=torch.float32, device=idx.device)
        one_hot_key = one_hot_key.scatter_(1, idx, 1)
        one_hot_key[:, 0] = 0  # ignore 0 index.
        logits = torch.softmax(input, dim=-1)
        loss = -self.alpha * one_hot_key * torch.pow((1 - logits), self.gamma) * (logits + self.epsilon).log()
        loss = loss.sum(1)
        return loss.mean()


if __name__ == '__main__':
    loss = FocalLoss(alpha=[0.1, 0.2, 0.3, 0.15, 0.25])
    input = torch.randn(3, 5, requires_grad=True)
    target = torch.empty(3, dtype=torch.long).random_(5)
    output = loss(input, target)
    print(output)
    output.backward()
  • 30
    点赞
  • 105
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 18
    评论
下面是 circle loss 和 focal loss 的简单实现(包括二分类多分类): Circle Loss: ```python import torch import torch.nn as nn import torch.nn.functional as F class CircleLoss(nn.Module): def __init__(self, m=0.25, s=30): super(CircleLoss, self).__init__() self.m = m self.s = s def forward(self, feats, labels): sim_mat = torch.matmul(feats, feats.t()) mask = labels.expand(labels.size(0), labels.size(0)).t().eq(labels.expand(labels.size(0), labels.size(0))) pos_mask = mask.triu(diagonal=1) neg_mask = mask.logical_not().triu(diagonal=1) pos_sim = sim_mat[pos_mask] neg_sim = sim_mat[neg_mask] alpha_p = F.relu(-pos_sim.detach() + 1 + self.m) alpha_n = F.relu(neg_sim.detach() + self.m) delta_p = 1 - self.m delta_n = self.m logit_p = -self.s * alpha_p * (pos_sim - delta_p) logit_n = self.s * alpha_n * (neg_sim - delta_n) logit = torch.cat([logit_p, logit_n], dim=0) loss = F.softplus(torch.logsumexp(logit, dim=0)) return loss ``` Focal Loss: ```python import torch import torch.nn as nn import torch.nn.functional as F class FocalLoss(nn.Module): def __init__(self, alpha=0.25, gamma=2, reduction='mean'): super(FocalLoss, self).__init__() self.alpha = alpha self.gamma = gamma self.reduction = reduction def forward(self, inputs, targets): BCE_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none') pt = torch.exp(-BCE_loss) F_loss = self.alpha * (1-pt)**self.gamma * BCE_loss if self.reduction == 'mean': return torch.mean(F_loss) elif self.reduction == 'sum': return torch.sum(F_loss) else: return F_loss ``` 以上代码适用于 PyTorch 深度学习框架。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 18
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

iioSnail

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值