非平衡数据损失函数之Focal Loss for Dense Object Detection

非平衡数据损失函数之Focal Loss for Dense Object Detection


最近在思考非平衡数据集的分类问题,总是觉得交叉熵(CE)即使带了权重也并不够orientied:在实际最近跑的实验中,发现即使正类样本能够全部召回,但是代价是false positve急剧增多。

Source来源

原论文来自何凯明大神,主要解决了one-stage网络识别率普遍低于two-stage网络的问题,其指出其根本原因是样本类别不均衡,因此通过改变传统的loss(CE)变为focal loss,瞬间提升了one-stage网络的准确率。

Problem setup问题背景

文章指出,one-stage网络是在训练阶段,极度不平衡的类别数量导致准确率下降,一张图片里背景负类样本远远高于正类样本,这导致分类负类样本的数目占据loss的极大部分,因此,这种不平衡导致了模型会把更多的重心放在背景样本的学习上去。
Focal loss的做法是改变原有的loss计算方式,避免过度沉溺与easy examples。

Loss function定义

先放公式:
在这里插入图片描述
对于普通的CE,由于负样本数量巨大,正样本很少,所以负样本被错分为正样本的的loss会占据loss的主导。那么好的做法就是,尽量减少负样本loss所占的比例,或者增大正样本被错分为负样本的loss所占的比例。

首先直接在CE前面乘以一个参数α,这样可以方便控制正负样本loss所占的比例,即如果是正样本,那么下式表示的就是正样本被错分为负样本的loss,接着乘以α用于调整这个loss的大小,显然应该放大这个loss:
在这里插入图片描述
然而,尽管这样可以做可以起到一些放大作用,但其效果也是不够的。
比如:如果分类的结果接近正确,正样本以0.9的概率被分为正样本,但这部分loss也会被放大,这是我们不希望看到的;此外,预测为0.4的正样本和预测为0.6的正样本的loss在这里相差也是不大的。
因此,我们希望把这个差距拉开,希望看到的是,被分类的足够好的样本loss不需要太大的权重,而被错分严重的,我们需要将他的loss放大,错分越严重loss应该被放大的越多,因此可以用下面的指数函数来实现:
在这里插入图片描述
由下面这张图可以看出来,当γ为5的时候,预测概率小于0.5的正样本因为乘了系数,可以将loss放到很大;而大于0.5的分类的很好的正样本的loss则被抑制为接近0。
在这里插入图片描述

Code realization代码实现

@Descripttion: This is Aoru Xue's demo, which is only for reference.
@version: 
@Author: Aoru Xue
@Date: 2018-12-26 08:04:34
@LastEditors  : Aoru Xue
@LastEditTime : 2018-12-26 08:16:09
'''
import torch
import torch.nn as nn

class FocalLoss(nn.Module):
    def __init__(self,gamma = 0.5):
        super(FocalLoss, self).__init__()
        self.gamma = gamma
    def forward(self,x,y):# (b,len) (b,1)
        '''
        FL = -(1-pt)**gamma * log(pt) if gt == 1
           = -pt**gamma * log(1-pt) if gt == 0
        利用乘法省去if
        '''
        pt = torch.sigmoid(x).view(-1,)

        losses = -(1 - pt)**self.gamma * torch.log(pt) * y - pt**self.gamma * torch.log(1-pt) * (1-y)
        return torch.sum(losses)

if __name__ == '__main__':
    focal_loss = FocalLoss()
    x = torch.Tensor([[0.1,0.5,0.7,0.8]])
    y = torch.LongTensor([[1,0,1,0]])
    loss = focal_loss(x,y)
    print(loss)

Conclusion总结

Focal loss能够避免梯度更新方向倾向easy examples主要有以下两点:

  1. Focal loss的本质是针对在非平衡数据集中,负类样本占据比例loss过多,而对正类样本错分的loss进行放大调节
  2. 不同程度错分的正类样本,被赋予不同的系数,并且随着置信度急剧下降,加强了对hard example的学习
  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值