一文读懂Focal Loss及Pytorch代码(详细注释)

广告位:

图像拼接论文精读专栏 —— 图像拼接领域论文全覆盖(包含数据集),省时省力读论文,带你理解晦涩难懂的论文算法,学习零散的知识和数学原理,并学会写图像拼接领域的论文(介绍、相关工作、算法、实验、结论、并附有参考文献,不用一篇一篇文章再找)

图像拼接论文源码精读专栏 —— 图像拼接有源码的论文全覆盖(有的自己复现),帮助你通过源码进一步理解论文算法,助你做实验,跑出拼接结果,得到评价指标RMSE、SSIM、PSNR等,并寻找潜在创新点和改进提升思路。

超分辨率重建专栏 —— 从SRCNN开始,带你读论文,写代码,复现结果,找创新点,完成论文。手把手教,保姆级攻略。帮助你顺利毕业,熟练掌握超分技术。

图像去噪专栏 —— 从DnCNN开始,100个经典的基于深度学习的图像去噪算法。读懂论文,看懂代码,复现结果,寻找创新,新手小白入门,保姆级攻略。帮助你顺利毕业,实现目标。

关注文章最底部微信公众号【十小大的底层视觉工坊】,获取最精炼版论文解读!

有需要的同学可以点上面链接看看。



前言

Focal Loss及RetinaNet原理见另一篇文章:【论文精读】Focal Loss for Dense Object Detection(RetinaNet)全文翻译及重点总结

本文介绍Focal Loss以及其Pytorch实现。


Focal Loss详解

直接上公式:
F L ( p t ) = − α t ( 1 − p t ) γ l o g ( p t ) FL(p_t)=-\alpha_t(1-p_t)^\gamma log(p_t) FL(pt)=αt(1pt)γlog(pt)
其中:
p t = { p , i f   y = 1 1 − p , o t h e r w i s e p_t= \begin{cases} p,\quad & if \ y=1 \\ 1-p, \quad & otherwise \end{cases} pt={p,1p,if y=1otherwise
交叉熵损失:
C E ( p t ) = − l o g ( p t ) CE(p_t) = -log(p_t) CE(pt)=log(pt)

  • p t p_t pt是类别(t个不同类别)概率(多分类就是softmax的结果),衡量样本难易程度,如果 p t p_t pt较大则是简单样本,较小则是困难样本;
  • α t \alpha_t αt是调节类别权重因子,它的值为第一类正样本权重,RetinaNet中设置为0.25,由于正样本远少于负样本,所以这样设置,让正样本的权重低,负样本权重为0.75;
  • γ \gamma γ是调节难易样本的权重因子,让模型快速关注困难样本。RetinaNet中设为2;
  • ( 1 − p t ) γ (1-p_t)^\gamma (1pt)γ整体控制损失大小,因为 γ = 2 \gamma=2 γ=2,所以 p t → 0 p_t \rightarrow 0 pt0时,loss大; p t → 1 p_t \rightarrow 1 pt1时,loss小;这样合理的权重分配可以让模型更好的学习训练;
  • α = 1 , γ = 0 \alpha = 1,\gamma = 0 α=1,γ=0时,FL就变成了CE。所以在复现的时候提供了一个思路,那就是定义前面两个权重因子,再乘上CE就得到了FL, p t p_t pt是softmax的结果,那么 l o g ( p t ) log(p_t) log(pt)就是log_softmax的结果。

实现思路

定义focal_loss类:

1.参数定义:alpha,gamma,num_classes等,应用到目标检测中可能还需要anchors等参数。

2.forward:focalloss是由CE乘以因子组成的,而CE是由NLL和log_softmax组成的。所以,先实现NLL和log_softmax,然后一步一步通过公式实现focal loss,主要是Tensor的变换。

3.输入参数设计:输入张量是[B,N,C]或[B,C],其中B是批量,N是预测框数量,C是类别数;[B,C]就是单纯的分类问题。真实标签是[B,N]或[B],总的标签数,如果是目标检测中则标签数就是B*N。


Focal Loss类的代码

from torch import nn
import torch
from torch.nn import functional as F

class focal_loss(nn.Module):
    def __init__(self, alpha=0.25, gamma=2, num_classes = 5, size_average=True):
        """
        focal_loss损失函数, -α(1-yi)**γ *ce_loss(xi,yi)
        步骤详细的实现了 focal_loss损失函数.
        :param alpha:   阿尔法α,类别权重. 当α是列表时,为各类别权重,当α为常数时,类别权重为[α, 1-α, 1-α, ....],常用于 目标检测算法中抑制背景类 , retainnet中设置为0.255
        :param gamma:   伽马γ,难易样本调节参数. retainnet中设置为2
        :param num_classes:     类别数量
        :param size_average:    损失计算方式,默认取均值
        """
        super(focal_loss,self).__init__()
        self.size_average = size_average
        if isinstance(alpha,list):
            assert len(alpha)==num_classes   # α可以以list方式输入,size:[num_classes] 用于对不同类别精细地赋予权重
            print(" --- Focal_loss alpha = {}, 将对每一类权重进行精细化赋值 --- ".format(alpha))
            self.alpha = torch.Tensor(alpha)
        else:
            assert alpha<1   #如果α为一个常数,则降低第一类的影响,在目标检测中为第一类
            print(" --- Focal_loss alpha = {} ,将对背景类进行衰减,请在目标检测任务中使用 --- ".format(alpha))
            self.alpha = torch.zeros(num_classes)
            self.alpha[0] += alpha
            self.alpha[1:] += (1-alpha) # α 最终为 [ α, 1-α, 1-α, 1-α, 1-α, ...] size:[num_classes]

        self.gamma = gamma

    def forward(self, preds, labels):
        """
        focal_loss损失计算
        :param preds:   预测类别. size:[B,N,C] or [B,C]    分别对应与检测与分类任务, B批次, N检测框数, C类别数
        :param labels:  实际类别. size:[B,N] or [B]        [B*N个标签(假设框中有目标)],[B个标签]
        :return:
        """
                
        #固定类别维度,其余合并(总检测框数或总批次数),preds.size(-1)是最后一个维度
        preds = preds.view(-1,preds.size(-1))
        self.alpha = self.alpha.to(preds.device)
        
        #使用log_softmax解决溢出问题,方便交叉熵计算而不用考虑值域
        preds_logsoft = F.log_softmax(preds, dim=1) 
        
     	#log_softmax是softmax+log运算,那再exp就算回去了变成softmax
        preds_softmax = torch.exp(preds_logsoft)    
   
        # 这部分实现nll_loss ( crossentropy = log_softmax + nll)
        preds_softmax = preds_softmax.gather(1,labels.view(-1,1)) 
        preds_logsoft = preds_logsoft.gather(1,labels.view(-1,1))

        self.alpha = self.alpha.gather(0,labels.view(-1))

        # torch.pow((1-preds_softmax), self.gamma) 为focal loss中 (1-pt)**γ
        
        #torch.mul 矩阵对应位置相乘,大小一致
        loss = -torch.mul(torch.pow((1-preds_softmax), self.gamma), preds_logsoft) 
    
        #torch.t()求转置
        loss = torch.mul(self.alpha, loss.t())
        #print(loss.size()) [1,5]
        
        if self.size_average:
            loss = loss.mean()
        else:
            loss = loss.sum()
       
        return loss

注:根据自己需要修改num_classes的值


具体流程分析代码

先简单一点测试一下图片分类:

假设有三张图片,分类类别一共五类,三张图片的真实类别类别数分别为2,3,4

torch.manual_seed(50)
preds = torch.randn((3,5))
#preds = torch.randn((3,10,5))
print(preds)
preds = preds.view(-1,preds.size(-1))
print(preds.size())

labels = torch.tensor([2,3,4])
#labels = torch.tensor([2,3,4]*10)

print(labels.view(-1))
print(labels.view(-1,1))

看一下各张量值以及size()后面会用到:
在这里插入图片描述
计算log_softmax以及softmax:

preds_logsoft = F.log_softmax(preds, dim=1) # log_softmax
print(preds_logsoft)
preds_softmax = torch.exp(preds_logsoft)
print(preds_softmax)

结果:
在这里插入图片描述
从中取得每个图片分别在2,3,4类别的预测值:如何取得可见另一篇文章详解。

preds_logsoft = preds_logsoft.gather(1,labels.view(-1,1))
print(preds_logsoft)    
preds_softmax = preds_softmax.gather(1,labels.view(-1,1)) 
print(preds_softmax)   

结果:
在这里插入图片描述
计算loss并转置:gamma取2

a = torch.pow((1-preds_softmax), 2)
loss = -torch.mul(torch.pow((1-preds_softmax), 2), preds_logsoft) 
print(loss.t())
print(loss.t().size())

结果:
在这里插入图片描述
初始化alpha权重,并得到所需label的alpha权重

alpha = torch.tensor([0.25,0.75,0.75,0.75,0.75])
alpha = alpha.gather(0,labels.view(-1))
alpha

结果:
在这里插入图片描述
乘以alpha并求平均得到最终的focal loss

loss = torch.mul(alpha, loss.t())
loss = loss.mean()
loss

得到focal loss为:
在这里插入图片描述

PS:若要加入检测框数量,输入采用代码中注释的部分,本例中是假设每张图片有10个检测框(正样本),则总标签个数就是30个,相应地改变标签张量为30个元素即可。


Focal Loss(FL)与CrossEntropy(CE)对比:

criterion = focal_loss()
loss = criterion(preds, labels)
print("FL   loss",loss)
a=F.cross_entropy(preds, labels)
#a=nn.CrossEntropyLoss()
#a = a(preds,labels)
print('CE   loss',a)

结果说明FL比CE的损失低,效果好
在这里插入图片描述


总结

  1. focal loss在目标检测中预测类别时使用,是分类子网下的损失。
  2. 先log_softmax再取exp得到的softmax各类别之和不是严格为1。
  3. F.cross_entropy和nn.CrossEntropyLoss相同。
  4. gather函数详解见【Pytorch小知识】torch.gather()函数的用法及在Focal Loss中的应用(详细易懂)
  5. 可以根据分类或检测类别中的目标多少来自定义alpha的权重。测试代码见资源(资源链接)

没有硬件条件,需要云服务的同学可以扫码看看:
请添加图片描述

Focal Loss是一种用于处理类别不平衡问题的损失函数。在训练深度学习模型时,由于数据集中不同类别的样本数量往往存在较大的差异,因此训练出的模型容易出现对数量较大的类别表现良好,对数量较小的类别表现较差的情况。Focal Loss通过调整样本的权重,使得模型更加关注难以分类的样本,从而提高模型在数量较小的类别上的性能。 下面是使用PyTorch实现多分类Focal Loss代码: ``` import torch import torch.nn as nn import torch.nn.functional as F class FocalLoss(nn.Module): def __init__(self, alpha=1, gamma=2, reduction='mean'): super(FocalLoss, self).__init__() self.alpha = alpha self.gamma = gamma self.reduction = reduction def forward(self, inputs, targets): ce_loss = F.cross_entropy(inputs, targets, reduction='none') pt = torch.exp(-ce_loss) focal_loss = self.alpha * (1 - pt) ** self.gamma * ce_loss if self.reduction == 'mean': return torch.mean(focal_loss) elif self.reduction == 'sum': return torch.sum(focal_loss) else: return focal_loss ``` 在这里,我们定义了一个名为FocalLoss的自定义损失函数,并在其构造函数中定义了三个参数。alpha参数用于平衡每个类别的权重,gamma参数用于调整样本难度的权重,reduction参数用于指定损失函数的计算方式(mean或sum)。 在forward函数中,我们首先计算普通的交叉熵损失(ce_loss),然后计算每个样本的难度系数(pt),最后计算Focal Lossfocal_loss)。最后根据reduction参数的设定,返回损失函数的值。 在使用Focal Loss时,我们需要在训练过程中将损失函数替换为Focal Loss即可。例如,如果我们使用了PyTorch的nn.CrossEntropyLoss作为损失函数,我们可以将其替换为FocalLoss: ``` criterion = FocalLoss(alpha=1, gamma=2) ``` 这样,在训练过程中就会使用Focal Loss作为损失函数,从而提高模型在数量较小的类别上的性能。
评论 18
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

十小大

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值