Focal Loss是 Kaiming 大神团队在他们的论文Focal Loss for Dense Object Detection 提出来的损失函数。Focal Loss的引入主要是为了解决难易样本数量不平衡(注意,有区别于正负样本数量不平衡)的问题,实际可以使用的范围非常广泛,为了方便解释,还是拿目标检测的应用场景来说明:单阶段的目标检测器通常会产生高达100k的候选目标,只有极少数是正样本,正负样本数量非常不平衡。我们在计算分类的时候常用的损失——交叉熵的公式如下:
为了解决正负样本不平衡的问题,我们通常会在交叉熵损失的前面加上一个参数alpha ,即:
尽管 alpha 平衡了正负样本,但对难易样本的不平衡没有任何帮助。而实际上,目标检测中大量的候选目标都是像下图一样的易分样本。于是论文中想到,把高置信度§样本的损失再降低一些不就好了,其形式为:
其代码实现可以参考下面这段代码:
def binary_focal_loss_fixed(y_true, y_pred):
"""
:param y_true: A tensor of the same shape as `y_pred`
:param y_pred: A tensor resulting from a sigmoid
:return: Output tensor.
"""
pt_1 = tf.where(tf.equal(y_true, 1), y_pred, tf.ones_like(y_pred))
pt_0 = tf.where(tf.equal(y_true, 0), y_pred, tf.zeros_like(y_pred))
epsilon = K.epsilon()
# clip to prevent NaN's and Inf's
pt_1 = K.clip(pt_1, epsilon, 1. - epsilon)
pt_0 = K.clip(pt_0, epsilon, 1. - epsilon)
return -K.sum(alpha * K.pow(1. - pt_1, gamma) * K.log(pt_1)) \
-K.sum((1 - alpha) * K.pow(pt_0, gamma) * K.log(1. - pt_0))