Tensorflow版本的Focal loss
1、区分logits,prob,prediction
logits: 是网络的原始输出,从代码中可以简单的理解为 logits = f (x, w) + bais。通常来说,输出的logits的维度是(batch_size, class_num)。
prob: 代表的是在分类输出时,每一个类别的概率。概率的总和为1。通常来讲,prob是logits经过softmax得到的概率分布。prob = softmax ( logits ),通常为(batch_size, class_num)。
prediction: 它是logits通过argmax之后的输出,通常为(batch_size)。
不同的loss函数要求的输入是不一样的,在focal loss 损失函数中的输入则是为最原始的输出logits
2、focal loss 损失函数
首先,我们得知道focal loss 其实是在交叉熵的损失函数的基础上进行修改的,首先看一下交叉熵的loss
其中 y’ 是经过softmax函数的输出,所以在0-1之间。可见普通的交叉熵对于正样本而言,输出概率越大损失越小。对于负样本而言,输出概率越小则损失越小。此时的损失函数在大量简单样本的迭代过程中比较缓慢且可能无法优化至最优。那么Focal loss是怎么改进的呢?