睿智的目标检测9——Focal loss详解及其实现

2020/5/13日更新

之前的focal_loss没有考虑存在目标的框中的其它种类标签为0,已经修改代码!请各位注意啦!

学习前言

最近觉得yolo3的训练效果还可以优化,除去改变特征提取网络外,还可以改变其loss的组成。
在这里插入图片描述

什么是Focal loss

Focal loss是何恺明大神提出的一种新的loss计算方案。
其具有两个重要的特点。

1、控制正负样本的权重

2、控制容易分类和难分类样本的权重

正负样本的概念如下:
一张图像可能生成成千上万的候选框,但是其中只有很少一部分是包含目标的的,有目标的就是正样本,没有目标的就是负样本。

容易分类和难分类样本的概念如下:
假设存在一个二分类,样本1属于类别1的pt=0.9,样本2属于类别1的pt=0.6,显然前者更可能是类别1,其就是容易分类的样本;后者有可能是类别1,所以其为难分类样本。

如何实现权重控制呢,请往下看:

控制正负样本的权重

如下是常用的交叉熵loss,以二分类为例:
在这里插入图片描述
我们可以利用如下Pt简化交叉熵loss。
在这里插入图片描述
此时:
在这里插入图片描述
想要降低负样本的影响,可以在常规的损失函数前增加一个系数αt。与Pt类似,当label=1的时候,αt=α;当label=otherwise的时候,αt=1 - α,a的范围也是0到1。此时我们便可以通过设置α实现控制正负样本对loss的贡献在这里插入图片描述
其中:
在这里插入图片描述
分解开就是:
在这里插入图片描述

控制容易分类和难分类样本的权重

按照刚才的思路,一个二分类,样本1属于类别1的pt=0.9,样本2属于类别1的pt=0.6,也就是 是某个类的概率越大,其越容易分类 所以利用1-Pt就可以计算出其属于容易分类或者难分类。
具体实现方式如下。
在这里插入图片描述
其中:
( 1 − p t ) γ (1-p_{t})^{γ} (1pt)γ
称为调制系数(modulating factor)

1、当pt趋于0的时候,调制系数趋于1,对于总的loss的贡献很大。当pt趋于1的时候,调制系数趋于0,也就是对于总的loss的贡献很小。
2、当γ=0的时候,focal loss就是传统的交叉熵损失,可以通过调整γ实现调制系数的改变。

两种权重控制方法合并

通过如下公式就可以实现控制正负样本的权重控制容易分类和难分类样本的权重
在这里插入图片描述

实现方式

def focal(alpha=0.25, gamma=2.0):
    def _focal(y_true, y_pred):
        # y_true [batch_size, num_anchor, num_classes+1]
        # y_pred [batch_size, num_anchor, num_classes]
        labels         = y_true[:, :, :-1]
        anchor_state   = y_true[:, :, -1]  # -1 是需要忽略的, 0 是背景, 1 是存在目标
        classification = y_pred

        # 找出存在目标的先验框
        indices_for_object        = backend.where(keras.backend.equal(anchor_state, 1))
        labels_for_object         = backend.gather_nd(labels, indices_for_object)
        classification_for_object = backend.gather_nd(classification, indices_for_object)

        # 计算每一个先验框应该有的权重
        alpha_factor_for_object = keras.backend.ones_like(labels_for_object) * alpha
        alpha_factor_for_object = backend.where(keras.backend.equal(labels_for_object, 1), alpha_factor_for_object, 1 - alpha_factor_for_object)
        focal_weight_for_object = backend.where(keras.backend.equal(labels_for_object, 1), 1 - classification_for_object, classification_for_object)
        focal_weight_for_object = alpha_factor_for_object * focal_weight_for_object ** gamma

        # 将权重乘上所求得的交叉熵
        cls_loss_for_object = focal_weight_for_object * keras.backend.binary_crossentropy(labels_for_object, classification_for_object)

        # 找出实际上为背景的先验框
        indices_for_back        = backend.where(keras.backend.equal(anchor_state, 0))
        labels_for_back         = backend.gather_nd(labels, indices_for_back)
        classification_for_back = backend.gather_nd(classification, indices_for_back)

        # 计算每一个先验框应该有的权重
        alpha_factor_for_back = keras.backend.ones_like(labels_for_back) * (1 - alpha)
        focal_weight_for_back = classification_for_back
        focal_weight_for_back = alpha_factor_for_back * focal_weight_for_back ** gamma

        # 将权重乘上所求得的交叉熵
        cls_loss_for_back = focal_weight_for_back * keras.backend.binary_crossentropy(labels_for_back, classification_for_back)

        # 标准化,实际上是正样本的数量
        normalizer = tf.where(keras.backend.equal(anchor_state, 1))
        normalizer = keras.backend.cast(keras.backend.shape(normalizer)[0], keras.backend.floatx())
        normalizer = keras.backend.maximum(keras.backend.cast_to_floatx(1.0), normalizer)

        # 将所获得的loss除上正样本的数量
        cls_loss_for_object = keras.backend.sum(cls_loss_for_object)
        cls_loss_for_back = keras.backend.sum(cls_loss_for_back)

        # 总的loss
        loss = (cls_loss_for_object + cls_loss_for_back)/normalizer

        return loss
    return _focal
  • 53
    点赞
  • 234
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 80
    评论
评论 80
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Bubbliiiing

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值