论文地址:https://arxiv.org/abs/1708.02002
代码地址:https://github.com/HeyLynne/FocalLoss_for_multiclass
1.是什么
Focal Loss是一个在目标检测领域常用的损失函数,它是何凯明大佬在RetinaNet网络中提出的,解决了目标检测中正负样本极不平衡和难分类样本学习的问题。Focal Loss通过引入一个可调节的聚焦参数,将容易分类的样本的权重降低,而将困难样本的权重提高,从而使得模型更加关注困难样本,提高了模型的性能。需要注意的是,Focal Loss损失函数容易受到噪声的干扰,因此训练集中标注的信息尽量不要出现错误的情况。
2.为什么?
目标检测算法大都是基于两种结构:一种是以R-CNN为代表的two-stage,proposal 驱动算法。这种算法在第一阶段针对目标样本生成一份比较稀疏的集合,第二阶段对这份集合进行分类和提取,两个阶段下来速度就大打折扣了。另一种是以YOLO,SSD为代表的one-stage的目标检测算法,只用一个阶段就完成目标样本的检测和回归,速度相对于two-stage目标检测算法自然是有所提升,但是效果却大打折扣。
为什么one-stage的目标检测算法效果要差于two-stage呢。文中认为这是因为训练过程中类别失衡造成的,在two-stage检测算法中,第一阶段已经过滤了大部分的背景,将目标缩小在一定的范围内。而对于one-stage检测来说样本中包含了大量没有目标的背景,这导致样本的比例失衡,训练的时候负样本过多,导致他的loss过大而淹没了正样本的loss不利于收敛。一种解决办法是难分负样本挖掘,然后对这些样本单独训练。
基于以上提出了Focal loss,降低易分样本的权重,提高难分样本的权重。
3.怎么样?
3.1交叉熵
首先我们先简单了解一下交叉熵。
在信息学中信息熵(entropy)是表示系统的混乱程度和确定性的。一条信息的信息量和他的确定程度有直接关系,如果他的确定程度很高那么我们不需要很大的信息量就可以了解这些信息,例如北京是中国的首都,我们是很确定的,不需要其他的信息就可以判断这条信息对不对。那么一个系统的熵如何计算呢:
它是表示系统的不确定性的度量,当x的状态越多信息熵就越大,当x均匀分布时熵最大。当我们的样本集有两个分布p(x) 表示真实分布, q(x)表示非真实分布,那么当我们用p(x)表示样本集的熵即为刚才我们说的信息熵。那么如果使用 q(x)表示样本的熵怎么表示呢?注意到此时样本的真实分布是p(x)这个就是交叉熵(cross entropy)了。
对于二分类问题来说,他的交叉熵是:
其中p表示y=1的概率,这里我们定义
那么交叉熵可以表示为:
这里我们来看一张收敛的模型在测试数据集中的梯度分布,图片来自困难样本(Hard Sample)处理方法。最左边梯度接近于0就是简单样本,简单样本的数量很多。中间部分是一些不同难度的样本,最右边就是loss很大的困难样本,这些样本在数量上相对于简单样本是非常少的,所以即使他们的梯度很大,但是如果使用交叉熵,那么他们对loss的贡献还是很少,所以他们还是很难学。
下图是文中所给的不同样本的loss分布,还是如我们刚才所讨论的。这些易分的样本loss虽然不高但是数量很多,所以导致困难样本的loss容易被这些简单样本所覆盖,导致他们更加难学习。而引入focal loss之后可以看到我们降低了简单样本的loss,从而提高了他们对梯度的贡献。那么什么是focal loss呢,我们下面将着重介绍focal loss.
3.2 Focal loss
对于二分类问题Focal loss计算如下:
对于那些概率较大的样本 趋近于0,可以降低它的loss值,而对于真实概率比较低的困难样本,(
对他们的loss影响并不大,这样一来我们可以通过降低简单样本loss的方法提高困难样本对梯度的贡献。同时为了提高误分类样本的权重,最终作者为Focal loss增加权重,Focal loss最终长这样:
3.3 代码实现
二分类
class Focal_Loss():
"""
二分类Focal Loss
"""
def __init__(self,alpha=0.25,gamma=2):
super(Focal_Loss,self).__init__()
self.alpha=alpha
self.gamma=gamma
def forward(self,preds,labels):
"""
preds:sigmoid的输出结果
labels:标签
"""
eps=1e-7
loss_1=-1*self.alpha*torch.pow((1-preds),self.gamma)*torch.log(preds+eps)*labels
loss_0=-1*(1-self.alpha)*torch.pow(preds,self.gamma)*torch.log(1-preds+eps)*(1-labels)
loss=loss_0+loss_1
return torch.mean(loss)
多分类
class Focal_Loss():
def __init__(self,weight,gamma=2):
super(Focal_Loss,self).__init__()
self.gamma=gamma
self.weight=weight
def forward(self,preds,labels):
"""
preds:softmax输出结果
labels:真实值
"""
eps=1e-7
y_pred =preds.view((preds.size()[0],preds.size()[1],-1)) #B*C*H*W->B*C*(H*W)
target=labels.view(y_pred.size()) #B*C*H*W->B*C*(H*W)
ce=-1*torch.log(y_pred+eps)*target
floss=torch.pow((1-y_pred),self.gamma)*ce
floss=torch.mul(floss,self.weight)
floss=torch.sum(floss,dim=1)
return torch.mean(floss)
参考: