深度学习| Focal Loss(包含代码实现)

前言:最近在解决图像类别不平衡的问题,之前介绍了DiceLoss,试了代码虽然又改善但还没解决问题。我要处理图像样本类别属于极度不均衡,了解到FocalLoss也能解决这个问题,于是就想写这篇文章作为记录。

介绍

解决什么问题:Focal Loss解决的是深度学习遇到类别不平衡的问题,直接用交叉熵损失函数计算损失函数,会使得最终结果偏向于常见类别。

如何解决这个问题:Focal Loss在交叉熵函数的基础上引入了超参数,增大类别少的样本的权重,以及调整易分类样本和困难样本之间的权重关系。

原理和公式

Focal Loss其实是在交叉熵损失函数(Cross Entropy Loss)上改进过来的。

交叉熵损失函数(Cross Entropy Loss)
H ( y , y ^ ) = − 1 N ∑ i = 1 N [ y i l o g ( y ^ i ) + ( 1 − y i ) l o g ( 1 − y ^ i ) ] H(y,\widehat{y})=- \frac{1}{N} \sum_{i=1}^N[y_ilog(\widehat{y}_i)+(1-y_i)log(1-\widehat{y}_i)] H(y,y )=N1i=1N[yilog(y i)+(1yi)log(1y i)]
这是一个二分类的CE公式,其中y是真实标签, y ^ \widehat{y} y 是预测值,N是样本的数量。
原理上,每个样本都会计算一个损失,然后对所有样本的损失求平均。
对于图像来说,这里的N可以看作是图像像素点的个数, y ^ \widehat{y} y 是预测每个像素点的值,y是每个像素点标签的值,一张图像的交叉熵其实就是计算每个像素点预测值和标签插值的平均。

这个CE公式也可以写成如下形式:
C E ( p t ) = − l o g ( p t ) CE(pt)=-log(pt) CE(pt)=log(pt)
p t = {   p , y = 1   1 − p , o t h e r w i s e p_t= \begin{cases} \ p, & y=1\\ \ 1-p, & otherwise \end{cases} pt={ p, 1p,y=1otherwise
p t p_t pt表示预测值和真实值之间的差。

Focal Loss公式:
在CE的基础上引入了超参数 γ \gamma γ α \alpha α,每个样本的损失构成了如下公式:
F L ( p t ) = − α ( 1 − p t ) γ l o g ( p t ) = α ( 1 − p t ) γ C E ( p t ) FL(p_t)=-\alpha(1-p_t)^\gamma log(p_t) =\alpha(1-p_t)^\gamma CE(pt) FL(pt)=α(1pt)γlog(pt)=α(1pt)γCE(pt)
其中 p t p_t pt是该样本某个类别的预测值,Focal Loss类别一般采用one-hot编码; α \alpha α是给不同类别样本加的权重,对于正样本比较少,就可以加大权重; γ \gamma γ的作用在于如果当前样本预测值 p t p_t pt比较大,就是易分类样本,就会使得 ( 1 − p t ) γ (1-p_t)^\gamma (1pt)γ减小。
其实也就相当于计算每个样本交叉熵前面加多了两个权重,一个是类别权重,一个是样本难易权重。类别权重可以更重视类别占比小的;样本难以权重可以更加关注困难样本。

所以实际上Focal Loss是解决了两个问题:样本不均+难易样本。

γ \gamma γ α \alpha α如何确定

在Focal Loss论文中,作者通过搜索一个范围来确定两个参数的最优解,最后给出的结果是 γ = 2 \gamma=2 γ=2 α = 0.25 \alpha=0.25 α=0.25。在该论文任务中,正样本是大大少于负样本的,而正样本参数 α = 0.25 \alpha=0.25 α=0.25,负样本参数 α = 0.75 \alpha=0.75 α=0.75,非常反直觉。经过 ( 1 − p t ) γ (1-p_t)^\gamma (1pt)γ p t γ {p_t}^\gamma ptγ后,正负样本之间的形式会逆转,还要通过 α \alpha α给正样本降权。

所以 γ \gamma γ α \alpha α的确定更多还是实验经验的结果,没有什么理论上的方法。

代码

class FocalLoss(nn.Module):
    def __init__(self, alpha=0.25, gamma=2.0, reduction='mean'):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction
        self.cross_entropy_loss = CrossEntropyLoss2d()
    
    def forward(self, inputs, targets):
    	# CE_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none')# 要求inputs和targets张量形状一样
        CE_loss = self.cross_entropy_loss(inputs, targets)# inputs可以是NxCxHxW,targets可以是NxHxW,会自动对其张量
        pt = torch.exp(-CE_loss) # 预测正确的概率
        F_loss = self.alpha * (1-pt)**self.gamma * CE_loss
        
        if self.reduction == 'mean':
            return torch.mean(F_loss)
        elif self.reduction == 'sum':
            return torch.sum(F_loss)
        else:
            return F_loss

代码理解上没什么难的,基本就是照着Focal Loss的公式照着写的。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值