paper: Focal Loss for Dense Object Detection
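The paper mainly addresses class imbalance.
Two-stage detectors work on candidate regions, whereas one-stage detectors face a large number of easily classified background examples, and these easy background examples waste a certain amount of training effort.
Focal loss addresses class imbalance and also down-weights easily classified examples, so that training concentrates on the hard ones.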
First, write the probability in the cross entropy in the simplified form pt.
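For reference, the simplification can be written out as follows. This uses the binary notation from the paper, where p is the model's estimated probability for the class y = 1; the notes themselves do not spell the notation out, so treat the symbols as an assumption about what is meant here.

```latex
p_t =
\begin{cases}
p     & \text{if } y = 1 \\
1 - p & \text{otherwise}
\end{cases}
\qquad
\mathrm{CE}(p, y) = \mathrm{CE}(p_t) = -\log(p_t)
```

In a multi-class softmax setting, pt is simply the predicted probability of the true class.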
To address class imbalance, a weighting factor alpha is applied:
In practice alpha may be set by inverse class frequency
or treated as a hyperparameter to set by cross validation.
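Written out, the alpha-balanced cross entropy looks like this. The definition of alpha_t mirrors that of pt in the paper's binary notation (alpha for y = 1, 1 - alpha otherwise), which is assumed here rather than stated in the notes.

```latex
\mathrm{CE}(p_t) = -\alpha_t \log(p_t),
\qquad
\alpha_t =
\begin{cases}
\alpha     & \text{if } y = 1 \\
1 - \alpha & \text{otherwise}
\end{cases}
```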
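To down-weight easy examples, the loss is scaled by (1-pt)^gamma: the larger pt is, e.g. 0.9, the easier the example is to classify, and with gamma=2 its weight shrinks by a factor of 100.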
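The arithmetic behind the factor of 100 can be checked with a few lines of plain Python. This is only an illustrative sketch; the helper name `modulating_factor` is made up for this example and does not come from the paper or from any library.

```python
def modulating_factor(pt: float, gamma: float = 2.0) -> float:
    """Return (1 - pt) ** gamma, the factor that scales the cross-entropy term."""
    return (1.0 - pt) ** gamma


# Compare an easy example (pt = 0.9) with harder ones at gamma = 2.
for pt in (0.9, 0.5, 0.1):
    print(f"pt={pt}: factor={modulating_factor(pt):.4f}")
```

With gamma=2 an easy example at pt=0.9 keeps only 0.01 of its loss, while a hard example at pt=0.1 keeps 0.81 of it, so the hard examples dominate the total.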
Combining the two both addresses class imbalance and down-weights easy examples.
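The combined form, as given in the paper, multiplies the cross-entropy term by both the class weight alpha_t and the modulating factor. Here pt is the predicted probability of the true class and gamma >= 0 is the focusing parameter.

```latex
\mathrm{FL}(p_t) = -\alpha_t \, (1 - p_t)^{\gamma} \, \log(p_t)
```

Setting gamma = 0 and alpha_t = 1 recovers the ordinary cross entropy.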
In the implementation, one alpha is defined per class:
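import torch.nn.functional as F

# inputs: (N, C) logits; target: (N, 1) int64 class indices; alpha: (C,) per-class weights
logpt = F.log_softmax(inputs, dim=1)   # log-probabilities over the classes
logpt = logpt.gather(1, target)        # log(pt): log-probability of the true class
logpt = logpt.view(-1)
pt = logpt.exp()                       # pt
at = alpha.gather(0, target.view(-1))  # alpha_t: weight of each sample's true class
logpt = logpt * at
loss = -1 * (1 - pt) ** gamma * logpt  # -alpha_t * (1 - pt)^gamma * log(pt)
# return loss.sum() or loss.mean()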
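To check the tensor code by hand, the same computation for a single sample can be sketched in plain Python. This is a hypothetical reference helper, not part of the paper or of PyTorch: the name `focal_loss_single` and its argument layout (a list of logits, a target index, a list of per-class alpha values) are assumptions made for illustration.

```python
import math


def focal_loss_single(logits, target, alpha, gamma=2.0):
    """Focal loss for one sample.

    logits: list of raw class scores
    target: index of the true class
    alpha:  list with one weight per class
    """
    # Numerically stable log-softmax, evaluated at the true class only.
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    logpt = logits[target] - log_sum_exp
    pt = math.exp(logpt)
    at = alpha[target]
    # -alpha_t * (1 - pt)^gamma * log(pt)
    return -at * (1.0 - pt) ** gamma * logpt


logits = [2.0, 0.5, -1.0]
uniform = [1.0, 1.0, 1.0]
print(focal_loss_single(logits, 0, uniform, gamma=0.0))  # plain cross entropy
print(focal_loss_single(logits, 0, uniform, gamma=2.0))  # down-weighted easy example
```

With gamma=0 and every alpha equal to 1 the function reduces to ordinary cross entropy, which makes a convenient sanity check for the tensor version.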