Focal loss详解及其实现

什么是Focal loss
Focal loss是何恺明大神提出的一种新的loss计算方案。
其具有两个重要的特点。
1、控制正负样本的权重
2、控制容易分类和难分类样本的权重

正负样本的概念如下:
一张图像可能生成成千上万的候选框,但是其中只有很少一部分是包含目标的的,有目标的就是正样本,没有目标的就是负样本。

容易分类和难分类样本的概念如下:
假设存在一个二分类,样本1属于类别1的pt=0.9,样本2属于类别1的pt=0.6,显然前者更可能是类别1,其就是容易分类的样本;后者有可能是类别1,所以其为难分类样本。

如何实现权重控制呢,请往下看:

控制正负样本的权重
如下是常用的交叉熵loss,以二分类为例:
在这里插入图片描述
我们可以利用如下Pt简化交叉熵loss。
在这里插入图片描述
此时:
在这里插入图片描述
想要降低负样本的影响,可以在常规的损失函数前增加一个系数αt。与Pt类似,当label=1的时候,αt=α;当label=otherwise的时候,αt=1 - α,a的范围也是0到1。此时我们便可以通过设置α实现控制正负样本对loss的贡献。
在这里插入图片描述
其中:
在这里插入图片描述
分解开就是:
在这里插入图片描述控制容易分类和难分类样本的权重
按照刚才的思路,一个二分类,样本1属于类别1的pt=0.9,样本2属于类别1的pt=0.6,也就是 是某个类的概率越大,其越容易分类 所以利用1-Pt就可以计算出其属于容易分类或者难分类。
具体实现方式如下。
在这里插入图片描述
1、当pt趋于0的时候,调制系数趋于1,对于总的loss的贡献很大。当pt趋于1的时候,调制系数趋于0,也就是对于总的loss的贡献很小。
2、当γ=0的时候,focal loss就是传统的交叉熵损失,可以通过调整γ实现调制系数的改变。

两种权重控制方法合并
通过如下公式就可以实现控制正负样本的权重和控制容易分类和难分类样本的权重。
在这里插入图片描述
实现方式

def focal(alpha=0.25, gamma=2.0):
    def _focal(y_true, y_pred):#两个输入
        # y_true [batch_size, num_anchor, num_classes+1]三个维度:样本数量、先验框的数量、种类数量+1
        # y_pred [batch_size, num_anchor, num_classes]
        labels         = y_true[:, :, :-1]
        anchor_state   = y_true[:, :, -1]  # -1 是需要忽略的, 0 是背景, 1 是存在目标 区分每一个anchor属于正样本还是负样本
        classification = y_pred

        # 找出存在目标的先验框 取出属于正样本的部分
        indices_for_object        = backend.where(keras.backend.equal(anchor_state, 1))
        labels_for_object         = backend.gather_nd(labels, indices_for_object)#正样本的标签
        classification_for_object = backend.gather_nd(classification, indices_for_object)#正样本的预测结果

        # 计算每一个先验框应该有的权重 计算正样本的权重
        alpha_factor_for_object = keras.backend.ones_like(labels_for_object) * alpha #αt
        alpha_factor_for_object = backend.where(keras.backend.equal(labels_for_object, 1), alpha_factor_for_object, 1 - alpha_factor_for_object) 
        focal_weight_for_object = backend.where(keras.backend.equal(labels_for_object, 1), 1 - classification_for_object, classification_for_object)#1-pt
        focal_weight_for_object = alpha_factor_for_object * focal_weight_for_object ** gamma

        # 将权重乘上所求得的交叉熵 所有正样本的loss
        cls_loss_for_object = focal_weight_for_object * keras.backend.binary_crossentropy(labels_for_object, classification_for_object)

        # 找出实际上为背景的先验框 取出所有属于负样本的部分
        indices_for_back        = backend.where(keras.backend.equal(anchor_state, 0))
        labels_for_back         = backend.gather_nd(labels, indices_for_back)
        classification_for_back = backend.gather_nd(classification, indices_for_back)

        # 计算每一个先验框应该有的权重 计算每一个负样本应该有的权重
        alpha_factor_for_back = keras.backend.ones_like(labels_for_back) * (1 - alpha)
        focal_weight_for_back = classification_for_back
        focal_weight_for_back = alpha_factor_for_back * focal_weight_for_back ** gamma

        # 将权重乘上所求得的交叉熵 计算所有负样本的loss
        cls_loss_for_back = focal_weight_for_back * keras.backend.binary_crossentropy(labels_for_back, classification_for_back)

        # 标准化,实际上是正样本的数量  
        normalizer = tf.where(keras.backend.equal(anchor_state, 1))
        normalizer = keras.backend.cast(keras.backend.shape(normalizer)[0], keras.backend.floatx())
        normalizer = keras.backend.maximum(keras.backend.cast_to_floatx(1.0), normalizer)

        # 将所获得的loss除上正样本的数量
        cls_loss_for_object = keras.backend.sum(cls_loss_for_object)
        cls_loss_for_back = keras.backend.sum(cls_loss_for_back)

        # 总的loss
        loss = (cls_loss_for_object + cls_loss_for_back)/normalizer

        return loss
    return _focal
  • 0
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
Focal Loss是一种用于处理类别不平衡问题的损失函数。在训练深度学习模型时,由于数据集中不同类别的样本数量往往存在较大的差异,因此训练出的模型容易出现对数量较大的类别表现良好,对数量较小的类别表现较差的情况。Focal Loss通过调整样本的权重,使得模型更加关注难以分类的样本,从而提高模型在数量较小的类别上的性能。 下面是使用PyTorch实现多分类Focal Loss的代码: ``` import torch import torch.nn as nn import torch.nn.functional as F class FocalLoss(nn.Module): def __init__(self, alpha=1, gamma=2, reduction='mean'): super(FocalLoss, self).__init__() self.alpha = alpha self.gamma = gamma self.reduction = reduction def forward(self, inputs, targets): ce_loss = F.cross_entropy(inputs, targets, reduction='none') pt = torch.exp(-ce_loss) focal_loss = self.alpha * (1 - pt) ** self.gamma * ce_loss if self.reduction == 'mean': return torch.mean(focal_loss) elif self.reduction == 'sum': return torch.sum(focal_loss) else: return focal_loss ``` 在这里,我们定义了一个名为FocalLoss的自定义损失函数,并在其构造函数中定义了三个参数。alpha参数用于平衡每个类别的权重,gamma参数用于调整样本难度的权重,reduction参数用于指定损失函数的计算方式(mean或sum)。 在forward函数中,我们首先计算普通的交叉熵损失(ce_loss),然后计算每个样本的难度系数(pt),最后计算Focal Loss(focal_loss)。最后根据reduction参数的设定,返回损失函数的值。 在使用Focal Loss时,我们需要在训练过程中将损失函数替换为Focal Loss即可。例如,如果我们使用了PyTorch的nn.CrossEntropyLoss作为损失函数,我们可以将其替换为FocalLoss: ``` criterion = FocalLoss(alpha=1, gamma=2) ``` 这样,在训练过程中就会使用Focal Loss作为损失函数,从而提高模型在数量较小的类别上的性能。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值