非平衡数据损失函数之Focal Loss for Dense Object Detection
最近在思考非平衡数据集的分类问题,总是觉得交叉熵(CE)即使带了权重也并不够orientied:在实际最近跑的实验中,发现即使正类样本能够全部召回,但是代价是false positve急剧增多。
Source来源
原论文来自何凯明大神,主要解决了one-stage网络识别率普遍低于two-stage网络的问题,其指出其根本原因是样本类别不均衡,因此通过改变传统的loss(CE)变为focal loss,瞬间提升了one-stage网络的准确率。
Problem setup问题背景
文章指出,one-stage网络是在训练阶段,极度不平衡的类别数量导致准确率下降,一张图片里背景负类样本远远高于正类样本,这导致分类负类样本的数目占据loss的极大部分,因此,这种不平衡导致了模型会把更多的重心放在背景样本的学习上去。
Focal loss的做法是改变原有的loss计算方式,避免过度沉溺与easy examples。
Loss function定义
先放公式:
对于普通的CE,由于负样本数量巨大,正样本很少,所以负样本被错分为正样本的的loss会占据loss的主导。那么好的做法就是,尽量减少负样本loss所占的比例,或者增大正样本被错分为负样本的loss所占的比例。
首先直接在CE前面乘以一个参数α,这样可以方便控制正负样本loss所占的比例,即如果是正样本,那么下式表示的就是正样本被错分为负样本的loss,接着乘以α用于调整这个loss的大小,显然应该放大这个loss:
然而,尽管这样可以做可以起到一些放大作用,但其效果也是不够的。
比如:如果分类的结果接近正确,正样本以0.9的概率被分为正样本,但这部分loss也会被放大,这是我们不希望看到的;此外,预测为0.4的正样本和预测为0.6的正样本的loss在这里相差也是不大的。
因此,我们希望把这个差距拉开,希望看到的是,被分类的足够好的样本loss不需要太大的权重,而被错分严重的,我们需要将他的loss放大,错分越严重loss应该被放大的越多,因此可以用下面的指数函数来实现:
由下面这张图可以看出来,当γ为5的时候,预测概率小于0.5的正样本因为乘了系数,可以将loss放到很大;而大于0.5的分类的很好的正样本的loss则被抑制为接近0。
Code realization代码实现
@Descripttion: This is Aoru Xue's demo, which is only for reference.
@version:
@Author: Aoru Xue
@Date: 2018-12-26 08:04:34
@LastEditors : Aoru Xue
@LastEditTime : 2018-12-26 08:16:09
'''
import torch
import torch.nn as nn
class FocalLoss(nn.Module):
def __init__(self,gamma = 0.5):
super(FocalLoss, self).__init__()
self.gamma = gamma
def forward(self,x,y):# (b,len) (b,1)
'''
FL = -(1-pt)**gamma * log(pt) if gt == 1
= -pt**gamma * log(1-pt) if gt == 0
利用乘法省去if
'''
pt = torch.sigmoid(x).view(-1,)
losses = -(1 - pt)**self.gamma * torch.log(pt) * y - pt**self.gamma * torch.log(1-pt) * (1-y)
return torch.sum(losses)
if __name__ == '__main__':
focal_loss = FocalLoss()
x = torch.Tensor([[0.1,0.5,0.7,0.8]])
y = torch.LongTensor([[1,0,1,0]])
loss = focal_loss(x,y)
print(loss)
Conclusion总结
Focal loss能够避免梯度更新方向倾向easy examples主要有以下两点:
- Focal loss的本质是针对在非平衡数据集中,负类样本占据比例loss过多,而对正类样本错分的loss进行放大调节
- 不同程度错分的正类样本,被赋予不同的系数,并且随着置信度急剧下降,加强了对hard example的学习