Focal Loss

In object detection, a single image can generate thousands of candidate locations, of which only a few actually contain an object; the overwhelming majority belong to a single class (background), which causes severe class imbalance. During training, these numerous background candidates dominate the loss, and because they all belong to one class and are easy to classify, they pull the optimization away from the direction we actually care about, namely detecting objects.

To address this class imbalance in the training data, the paper Focal Loss for Dense Object Detection proposes focal loss: starting from the cross-entropy loss, it down-weights easily classified examples so that training focuses on the hard ones. The formula is:
$$\mathrm{FL}(p_t) = -(1 - p_t)^{\gamma}\log(p_t), \qquad
p_t = \begin{cases} p & \text{if } y = 1 \\ 1 - p & \text{otherwise,} \end{cases}$$

where $p$ is the model's estimated probability for the class $y = 1$.

Here $\gamma \ge 0$ is the focusing parameter.
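For intuition, here is a quick numeric sketch (plain Python; the $p_t$ values are chosen only for illustration) of how the modulating factor $(1 - p_t)^{\gamma}$ shrinks the loss of well-classified examples relative to plain cross entropy:

import math

gamma = 2.0
for pt in (0.9, 0.6, 0.1):           # from easy to hard examples
    ce = -math.log(pt)               # plain cross-entropy loss
    fl = (1. - pt) ** gamma * ce     # focal loss without alpha weighting
    print("p_t=%.1f  CE=%.4f  FL=%.4f  ratio=%.4f" % (pt, ce, fl, fl / ce))

With $\gamma = 2$, an example already classified at $p_t = 0.9$ contributes 100x less than under cross entropy, while a hard example at $p_t = 0.1$ is barely down-weighted.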

The paper additionally introduces a weighting factor $\alpha_t$ to balance the contribution of positive and negative samples:

$$\mathrm{FL}(p_t) = -\alpha_t (1 - p_t)^{\gamma}\log(p_t), \qquad
\alpha_t = \begin{cases} \alpha & \text{if } y = 1 \\ 1 - \alpha & \text{otherwise.} \end{cases}$$
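Putting the two pieces together, the per-example weight applied to $-\log(p_t)$ is $\alpha_t (1 - p_t)^{\gamma}$. A small sketch (focal_weight is a hypothetical helper; the probabilities are illustrative):

alpha, gamma = 0.25, 2.0

def focal_weight(p, y):
    """Weight alpha_t * (1 - p_t)^gamma applied to -log(p_t)."""
    pt = p if y == 1 else 1. - p           # p_t from the definition above
    at = alpha if y == 1 else 1. - alpha   # alpha_t, defined analogously
    return at * (1. - pt) ** gamma

print(focal_weight(0.9, 1))  # easy positive -> 0.0025
print(focal_weight(0.1, 0))  # easy negative -> 0.0075 (0.75 * 0.01)
print(focal_weight(0.1, 1))  # hard positive -> 0.2025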

A TensorFlow implementation of focal loss:

import tensorflow as tf


def focal_loss(onehot_labels, cls_preds,
               alpha=0.25, gamma=2.0, name=None, scope=None):
    """Compute sigmoid focal loss between logits and onehot labels.

    cls_preds and onehot_labels must have the same shape
    [batch_size, num_classes] and the same data type (float16, 32, 64).

    Args:
      onehot_labels: Each row labels[i] must be a valid probability distribution.
      cls_preds: Unscaled log probabilities (logits).
      alpha: Weighting factor for balancing positive and negative samples,
        default is 0.25.
      gamma: Focusing parameter for penalizing easy examples, default is 2.0.
      name: A name for the operation (optional).
      scope: A name scope for the operation (optional).

    Returns:
      A 1-D tensor of length batch_size of the same type as logits,
      containing the focal loss of each example.
    """
    with tf.name_scope(scope, 'focal_loss', [cls_preds, onehot_labels]):
        logits = tf.convert_to_tensor(cls_preds)
        onehot_labels = tf.convert_to_tensor(onehot_labels)

        # Compute in float32 for numerical stability when given float16 logits.
        precise_logits = tf.cast(logits, tf.float32) if (
            logits.dtype == tf.float16) else logits
        onehot_labels = tf.cast(onehot_labels, precise_logits.dtype)

        # p_t: the predicted probability of the true class for each entry.
        predictions = tf.nn.sigmoid(precise_logits)
        predictions_pt = tf.where(tf.equal(onehot_labels, 1.0),
                                  predictions, 1. - predictions)

        # Small constant to avoid log(0).
        epsilon = 1e-8

        # alpha_t: alpha for positive entries, 1 - alpha for negative entries.
        alpha_t = tf.scalar_mul(alpha,
                                tf.ones_like(onehot_labels, dtype=tf.float32))
        alpha_t = tf.where(tf.equal(onehot_labels, 1.0), alpha_t, 1. - alpha_t)

        # -alpha_t * (1 - p_t)^gamma * log(p_t), summed over classes; the
        # factor onehot_labels keeps only the true-class term of each row.
        losses = tf.reduce_sum(
            -alpha_t * tf.pow(1. - predictions_pt, gamma)
            * onehot_labels * tf.log(predictions_pt + epsilon),
            name=name, axis=1)
        return losses
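For completeness, a minimal usage sketch (this assumes TensorFlow 1.x, since the implementation above uses the TF1-style tf.name_scope and tf.log APIs; the labels and logits below are illustrative):

import tensorflow as tf

labels = tf.constant([[0., 1., 0.],
                      [1., 0., 0.]])       # one-hot labels, [batch, num_classes]
logits = tf.constant([[-1.2, 2.3, 0.4],
                      [0.5, -0.7, -1.1]])  # raw classifier scores

loss = focal_loss(labels, logits, alpha=0.25, gamma=2.0)

with tf.Session() as sess:
    print(sess.run(loss))                  # one focal-loss value per example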