本文主要从二值交叉熵损失函数出发,通过代码实现的方式,去更好地理解Focal Loss对于数据不平衡问题、难易样本问题损失是如何权衡的。
1. 首先我们给出比较官方一些的代码,具体就是mmdet中的py_sigmoid_focal_loss函数。
def py_sigmoid_focal_loss(pred,
target,
weight=None,
gamma=2.0,
alpha=0.25,
reduction='mean',
avg_factor=None):
"""
PyTorch version of `Focal Loss <https://arxiv.org/abs/1708.02002>`_.
"""
pred_sigmoid = pred.sigmoid()
target = target.type_as(pred)
pt = (1 - pred_sigmoid) * target + pred_sigmoid * (1 - target)
focal_weight = (alpha * target + (1 - alpha) *
(1 - target)) * pt.pow(gamma)
loss = F.binary_cross_entropy_with_logits(
pred, target, reduction='none') * focal_weight
loss = weight_reduce_loss(loss, weight, reduction, avg_factor)
return loss
2. 根据理解,自己写得更简单直观的代码。
class BCEFocalLoss(nn.Module):
def __init__(self,alpha=0.25,gamma=2):
super(BCEFocalLoss, self).__init__()
self.alpha = alpha
self.gamma = gamma
def forward(self,logits,label):
label = label.unsqueeze(1) # size(N, 1)
assert label.size() == logits.size()
probs = torch.sigmoid(logits)
pos_loss = -label*self.alpha*probs.log()*(1-probs)**self.gamma
neg_loss = -(1-label)*(1-self.alpha)*(1-probs).log()*probs**self.gamma
loss = (pos_loss + neg_loss).mean()
return loss
3. 简单理论理解
比如,我们在做猫狗分类的任务,其中猫咪的图片有1000张,狗子的图片有300张,常见的二值交叉熵损失函数会倾向于学习到更多关于猫咪的知识,与此同时,会学到很少关于狗子的知识,这显然会让我们的分类器在识别狗子时容易失误,可以认为模型缺乏对狗子的理解。
因此,对于猫咪的图片其预测概率会更加置信,接近于1,此时focal loss的调制因子就起到了一种约束作用,其中 会更加接近于0,而对于分类不准确狗子的样本,损失基本没有改变,整体而言,相当于增加了分类不准确样本在损失函数中的权重。
上述的描述是从样本量方面解释了focal loss对于难易样本的约束,宏观理解就是样本量大的通常更加容易学习,样本量少的损失通常更加容易被样本量大的损失盖住,降低其损失影响。
当然,不管是样本多的类,还是样本少的类,都是存在难易样本的,因此,focal loss对于这种情况也是发挥作用的。
4. 参数分析
(1)其中gamma作用用于调节难易样本对于总loss的权重,其值越大,调整因子的影响也越大,这里最佳取值在实验中设置为2。
(2)其中平衡因子alpha,主要用来平衡正负样本比例不均的,从理论来讲,对于正样本,比如狗子图片,其数量相对来说更少,我们应该采用一个大于0.5的alpha值,来平衡类别之间的权重,但实际实验中,论文采用了更加合适的取值0.25,这主要可能是因为gamma参数的影响占据了更大的作用,alpha在这里起到了一个额外的辅助微调整作用,避免了整体的矫枉过正或者力有不逮的情况。