一文读懂Focal Loss及Pytorch代码(详细注释)

广告位:

图像拼接论文精读专栏 —— 图像拼接领域论文全覆盖(包含数据集),省时省力读论文,带你理解晦涩难懂的论文算法,学习零散的知识和数学原理,并学会写图像拼接领域的论文(介绍、相关工作、算法、实验、结论、并附有参考文献,不用一篇一篇文章再找)

图像拼接论文源码精读专栏 —— 图像拼接有源码的论文全覆盖(有的自己复现),帮助你通过源码进一步理解论文算法,助你做实验,跑出拼接结果,得到评价指标RMSE、SSIM、PSNR等,并寻找潜在创新点和改进提升思路。

超分辨率重建专栏 —— 从SRCNN开始,带你读论文,写代码,复现结果,找创新点,完成论文。手把手教,保姆级攻略。帮助你顺利毕业,熟练掌握超分技术。

有需要的同学可以点上面链接看看。



前言

Focal Loss及RetinaNet原理见另一篇文章:【论文精读】Focal Loss for Dense Object Detection(RetinaNet)全文翻译及重点总结

本文介绍Focal Loss以及其Pytorch实现。


Focal Loss详解

直接上公式:
F L ( p t ) = − α t ( 1 − p t ) γ l o g ( p t ) FL(p_t)=-\alpha_t(1-p_t)^\gamma log(p_t) FL(pt)=αt(1pt)γlog(pt)
其中:
p t = { p , i f   y = 1 1 − p , o t h e r w i s e p_t= \begin{cases} p,\quad & if \ y=1 \\ 1-p, \quad & otherwise \end{cases} pt={p,1p,if y=1otherwise
交叉熵损失:
C E ( p t ) = − l o g ( p t ) CE(p_t) = -log(p_t) CE(pt)=log(pt)

  • p t p_t pt是类别(t个不同类别)概率(多分类就是softmax的结果),衡量样本难易程度,如果 p t p_t pt较大则是简单样本,较小则是困难样本;
  • α t \alpha_t αt是调节类别权重因子,它的值为第一类正样本权重,RetinaNet中设置为0.25,由于正样本远少于负样本,所以这样设置,让正样本的权重低,负样本权重为0.75;
  • γ \gamma γ是调节难易样本的权重因子,让模型快速关注困难样本。RetinaNet中设为2;
  • ( 1 − p t ) γ (1-p_t)^\gamma (1pt)γ整体控制损失大小,因为 γ = 2 \gamma=2 γ=2,所以 p t → 0 p_t \rightarrow 0 pt0时,loss大; p t → 1 p_t \rightarrow 1 pt1时,loss小;这样合理的权重分配可以让模型更好的学习训练;
  • α = 1 , γ = 0 \alpha = 1,\gamma = 0 α=1,γ=0时,FL就变成了CE。所以在复现的时候提供了一个思路,那就是定义前面两个权重因子,再乘上CE就得到了FL, p t p_t pt是softmax的结果,那么 l o g ( p t ) log(p_t) log(pt)就是log_softmax的结果。

实现思路

定义focal_loss类:

1.参数定义:alpha,gamma,num_classes等,应用到目标检测中可能还需要anchors等参数。

2.forward:focalloss是由CE乘以因子组成的,而CE是由NLL和log_softmax组成的。所以,先实现NLL和log_softmax,然后一步一步通过公式实现focal loss,主要是Tensor的变换。

3.输入参数设计:输入张量是[B,N,C]或[B,C],其中B是批量,N是预测框数量,C是类别数;[B,C]就是单纯的分类问题。真实标签是[B,N]或[B],总的标签数,如果是目标检测中则标签数就是B*N。


Focal Loss类的代码

from torch import nn
import torch
from torch.nn import functional as F

class focal_loss(nn.Module):
    def __init__(self, alpha=0.25, gamma=2, num_classes = 5, size_average=True):
        """
        focal_loss损失函数, -α(1-yi)**γ *ce_loss(xi,yi)
        步骤详细的实现了 focal_loss损失函数.
        :param alpha:   阿尔法α,类别权重. 当α是列表时,为各类别权重,当α为常数时,类别权重为[α, 1-α, 1-α, ....],常用于 目标检测算法中抑制背景类 , retainnet中设置为0.255
        :param gamma:   伽马γ,难易样本调节参数. retainnet中设置为2
        :param num_classes:     类别数量
        :param size_average:    损失计算方式,默认取均值
        """
        super(focal_loss,self).__init__()
        self.size_average = size_average
        if isinstance(alpha,list):
            assert len(alpha)==num_classes   # α可以以list方式输入,size:[num_classes] 用于对不同类别精细地赋予权重
            print(" --- Focal_loss alpha = {}, 将对每一类权重进行精细化赋值 --- ".format(alpha))
            self.alpha = torch.Tensor(alpha)
        else:
            assert alpha<1   #如果α为一个常数,则降低第一类的影响,在目标检测中为第一类
            print(" --- Focal_loss alpha = {} ,将对背景类进行衰减,请在目标检测任务中使用 --- ".format(alpha))
            self.alpha = torch.zeros(num_classes)
            self.alpha[0] += alpha
            self.alpha[1:] += (1-alpha) # α 最终为 [ α, 1-α, 1-α, 1-α, 1-α, ...] size:[num_classes]

        self.gamma = gamma

    def forward(self, preds, labels):
        """
        focal_loss损失计算
        :param preds:   预测类别. size:[B,N,C] or [B,C]    分别对应与检测与分类任务, B批次, N检测框数, C类别数
        :param labels:  实际类别. size:[B,N] or [B]        [B*N个标签(假设框中有目标)],[B个标签]
        :return:
        """
                
        #固定类别维度,其余合并(总检测框数或总批次数),preds.size(-1)是最后一个维度
        preds = preds.view(-1,preds.size(-1))
        self.alpha = self.alpha.to(preds.device)
        
        #使用log_softmax解决溢出问题,方便交叉熵计算而不用考虑值域
        preds_logsoft = F.log_softmax(preds, dim=1) 
        
     	#log_softmax是softmax+log运算,那再exp就算回去了变成softmax
        preds_softmax = torch.exp(preds_logsoft)    
   
        # 这部分实现nll_loss ( crossentropy = log_softmax + nll)
        preds_softmax = preds_softmax.gather(1,labels.view(-1,1)) 
        preds_logsoft = preds_logsoft.gather(1,labels.view(-1,1))

        self.alpha = self.alpha.gather(0,labels.view(-1))

        # torch.pow((1-preds_softmax), self.gamma) 为focal loss中 (1-pt)**γ
        
        #torch.mul 矩阵对应位置相乘,大小一致
        loss = -torch.mul(torch.pow((1-preds_softmax), self.gamma), preds_logsoft) 
    
        #torch.t()求转置
        loss = torch.mul(self.alpha, loss.t())
        #print(loss.size()) [1,5]
        
        if self.size_average:
            loss = loss.mean()
        else:
            loss = loss.sum()
       
        return loss

注:根据自己需要修改num_classes的值


具体流程分析代码

先简单一点测试一下图片分类:

假设有三张图片,分类类别一共五类,三张图片的真实类别类别数分别为2,3,4

torch.manual_seed(50)
preds = torch.randn((3,5))
#preds = torch.randn((3,10,5))
print(preds)
preds = preds.view(-1,preds.size(-1))
print(preds.size())

labels = torch.tensor([2,3,4])
#labels = torch.tensor([2,3,4]*10)

print(labels.view(-1))
print(labels.view(-1,1))

看一下各张量值以及size()后面会用到:
在这里插入图片描述
计算log_softmax以及softmax:

preds_logsoft = F.log_softmax(preds, dim=1) # log_softmax
print(preds_logsoft)
preds_softmax = torch.exp(preds_logsoft)
print(preds_softmax)

结果:
在这里插入图片描述
从中取得每个图片分别在2,3,4类别的预测值:如何取得可见另一篇文章详解。

preds_logsoft = preds_logsoft.gather(1,labels.view(-1,1))
print(preds_logsoft)    
preds_softmax = preds_softmax.gather(1,labels.view(-1,1)) 
print(preds_softmax)   

结果:
在这里插入图片描述
计算loss并转置:gamma取2

a = torch.pow((1-preds_softmax), 2)
loss = -torch.mul(torch.pow((1-preds_softmax), 2), preds_logsoft) 
print(loss.t())
print(loss.t().size())

结果:
在这里插入图片描述
初始化alpha权重,并得到所需label的alpha权重

alpha = torch.tensor([0.25,0.75,0.75,0.75,0.75])
alpha = alpha.gather(0,labels.view(-1))
alpha

结果:
在这里插入图片描述
乘以alpha并求平均得到最终的focal loss

loss = torch.mul(alpha, loss.t())
loss = loss.mean()
loss

得到focal loss为:
在这里插入图片描述

PS:若要加入检测框数量,输入采用代码中注释的部分,本例中是假设每张图片有10个检测框(正样本),则总标签个数就是30个,相应地改变标签张量为30个元素即可。


Focal Loss(FL)与CrossEntropy(CE)对比:

criterion = focal_loss()
loss = criterion(preds, labels)
print("FL   loss",loss)
a=F.cross_entropy(preds, labels)
#a=nn.CrossEntropyLoss()
#a = a(preds,labels)
print('CE   loss',a)

结果说明FL比CE的损失低,效果好
在这里插入图片描述


总结

  1. focal loss在目标检测中预测类别时使用,是分类子网下的损失。
  2. 先log_softmax再取exp得到的softmax各类别之和不是严格为1。
  3. F.cross_entropy和nn.CrossEntropyLoss相同。
  4. gather函数详解见【Pytorch小知识】torch.gather()函数的用法及在Focal Loss中的应用(详细易懂)
  5. 可以根据分类或检测类别中的目标多少来自定义alpha的权重。测试代码见资源(资源链接)

没有硬件条件,需要云服务的同学可以扫码看看:
请添加图片描述

  • 19
    点赞
  • 64
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 18
    评论
以下是 PyTorch 中实现 Focal Loss代码: ```python import torch import torch.nn as nn import torch.nn.functional as F class FocalLoss(nn.Module): def __init__(self, gamma=2, alpha=None, size_average=True): super(FocalLoss, self).__init__() self.gamma = gamma self.alpha = alpha if isinstance(alpha, (float, int)): self.alpha = torch.Tensor([alpha, 1 - alpha]) if isinstance(alpha, list): self.alpha = torch.Tensor(alpha) self.size_average = size_average def forward(self, input, target): if input.dim() > 2: input = input.view(input.size(0), input.size(1), -1) # N,C,H,W => N,C,H*W input = input.transpose(1, 2) # N,C,H*W => N,H*W,C input = input.contiguous().view(-1, input.size(2)) # N,H*W,C => N*H*W,C target = target.view(-1, 1) logpt = F.log_softmax(input) logpt = logpt.gather(1, target) logpt = logpt.view(-1) pt = logpt.data.exp() if self.alpha is not None: if self.alpha.type() != input.data.type(): self.alpha = self.alpha.type_as(input.data) at = self.alpha.gather(0, target.data.view(-1)) logpt = logpt * at loss = -1 * (1 - pt) ** self.gamma * logpt if self.size_average: return loss.mean() else: return loss.sum() ``` 可以看到,该代码中首先定义了一个 `FocalLoss` 类,该类继承自 PyTorch 中的 `nn.Module` 类,因此我们可以直接使用该类来定义我们的 Focal Loss 模型。 在 `__init__()` 方法中,我们定义了两个超参数 `gamma` 和 `alpha`。其中 `gamma` 的值默认为 2,即 Focal Loss 中的调节因子。`alpha` 表示每个类别的权重,如果 `alpha` 是一个浮点数,则表示正样本的权重,负样本的权重为 1 - `alpha`。如果 `alpha` 是一个列表,则它的长度应该等于类别数,每个元素表示每个类别的权重。 在 `forward()` 方法中,我们首先将输入的 `input` 和 `target` 二者都展平成一维向量,然后计算损失函数。具体而言,我们首先对 `input` 进行 softmax 操作,然后取出对应类别的概率值 `pt`,接着根据 `alpha` 权重计算加权的对数概率值 `logpt`。最后根据 Focal Loss 的公式计算损失,并返回平均值或总和。
评论 18
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

十小大

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值