Focal Loss损失函数详解

       Focal Loss损失函数是何凯明大神在RetinaNet网络中提出的,解决了one-stage目标检测中正负样本极不平衡和难分类样本学习的问题,下面我们来详细理解一下该函数。

背景

  • 什么是正负样本极不平衡?

       目标检测算法为了定位目标会生成大量的anchor box,而一幅图中目标(正样本)个数很少,大量的anchor box处于背景区域(负样本),这就导致了正负样本极不平衡

  • two-stage为什么可以避免样本极不平衡?

       two-stage方法在第一阶段生成候选框,RPN只是对anchor box进行简单背景和前景的区分,并不对类别进行区分,经过这一轮处理,过滤掉了大部分属于背景的anchor box,较大程度降低了anchor box正负样本的不平衡性(注意:只是减轻了样本不平衡并没有解决样本不平衡);同时在第二阶段采用启发式采样(如:正负样本比1:3)或者OHEM进一步减轻正负样本不平衡的问题。

  • one-stage为什么不能避免样本极不平衡?

       one-stage方法为了提高检测速度,舍弃了生成候选框这一阶段,直接对anchor box进行难度更大的细分类(不只是区分前景背景,还区分anchor box属于什么类别),缺少了对anchor box的筛选过程。

Focal Loss

交叉熵

二分类交叉熵损失函数

                   

现定义如下的p_{t}

                   

得到变形后的损失函数如下:

                 

平衡交叉熵

       一般为了解决类别不平衡的问题,会在损失函数中每个类别前增加一个权重因子\alpha _{i}α^{i} ∈ [0, 1]来协调类别不平衡。使用p_{t}类似的方式定义\alpha _{t},得到二分类平衡交叉熵损失函数:

                     

Focal Loss

       类别极度不平衡在训练中,易分类负样本占了损失函数大部分,支配了梯度,会压垮交叉熵损失函数。平衡交叉熵采用\alpha平衡正负样本的重要性,但是没有区分难易样本。Focal Loss在平衡交叉熵损失函数的基础上,增加一个调节因子降低易分类样本权重,聚焦于困难样本的训练,其定义如下:

                            

其中,(1-p_{t})^{\gamma }调节因子\gamma≥ 0是可调节的聚焦参数,下图展示了\gamma ∈ [0, 5]不同值时focal loss曲线

                      

下面分析一下Focal Loss的特点:

  • p_{t}很小时(样本难分,不管分的是否正确),调节因子趋近1,损失函数中样本的权重不受影响;当p_{t}很大时(样本易分,不管分的是否正确),调节因子趋近0,损失函数中样本的权重下降很多
  • 聚焦参数\gamma可以调节易分类样本权重的降低程度,\gamma越大权重降低程度越大

通过分析Focal Loss函数的特点可知,该损失函数降低了易分类样本的权重,聚焦在难分类样本上。

  • 6
    点赞
  • 36
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
Focal Loss是一种用于处理类别不平衡问题的损失函数。在训练深度学习模型时,由于数据集中不同类别的样本数量往往存在较大的差异,因此训练出的模型容易出现对数量较大的类别表现良好,对数量较小的类别表现较差的情况。Focal Loss通过调整样本的权重,使得模型更加关注难以分类的样本,从而提高模型在数量较小的类别上的性能。 下面是使用PyTorch实现多分类Focal Loss的代码: ``` import torch import torch.nn as nn import torch.nn.functional as F class FocalLoss(nn.Module): def __init__(self, alpha=1, gamma=2, reduction='mean'): super(FocalLoss, self).__init__() self.alpha = alpha self.gamma = gamma self.reduction = reduction def forward(self, inputs, targets): ce_loss = F.cross_entropy(inputs, targets, reduction='none') pt = torch.exp(-ce_loss) focal_loss = self.alpha * (1 - pt) ** self.gamma * ce_loss if self.reduction == 'mean': return torch.mean(focal_loss) elif self.reduction == 'sum': return torch.sum(focal_loss) else: return focal_loss ``` 在这里,我们定义了一个名为FocalLoss的自定义损失函数,并在其构造函数中定义了三个参数。alpha参数用于平衡每个类别的权重,gamma参数用于调整样本难度的权重,reduction参数用于指定损失函数的计算方式(mean或sum)。 在forward函数中,我们首先计算普通的交叉熵损失(ce_loss),然后计算每个样本的难度系数(pt),最后计算Focal Loss(focal_loss)。最后根据reduction参数的设定,返回损失函数的值。 在使用Focal Loss时,我们需要在训练过程中将损失函数替换为Focal Loss即可。例如,如果我们使用了PyTorch的nn.CrossEntropyLoss作为损失函数,我们可以将其替换为FocalLoss: ``` criterion = FocalLoss(alpha=1, gamma=2) ``` 这样,在训练过程中就会使用Focal Loss作为损失函数,从而提高模型在数量较小的类别上的性能。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值