The focal loss function is mainly designed to address the severe imbalance between positive and negative samples in one-stage object detection. It down-weights the contribution of the large number of easy negatives during training, and can also be understood as a form of hard example mining.
Focal loss is obtained by modifying the standard cross-entropy loss:

FL(p_t) = -alpha_t * (1 - p_t)^gamma * log(p_t), where p_t = p if y = 1 and p_t = 1 - p otherwise.

Here y denotes the ground-truth label and p the predicted probability. The balancing factor alpha compensates for the inherent imbalance between positive and negative samples; gamma controls how quickly the weight of easy samples decays: with gamma = 0 the loss reduces to ordinary cross entropy, and as gamma increases, so does the influence of the modulating factor (1 - p_t)^gamma. The paper's experiments found gamma = 2 to be optimal, with alpha = 0.25, i.e. positive samples receive a smaller weight than negatives, because the negatives are easy to classify.
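To see the down-weighting concretely, here is a minimal sketch (the focal_term helper and the probability values are made up for illustration; gamma = 2 and alpha = 0.25 are the paper's defaults) comparing an easy, well-classified sample (p_t = 0.9) against a hard one (p_t = 0.1):

import math

def focal_term(p_t, alpha=0.25, gamma=2.0):
    # FL(p_t) = -alpha * (1 - p_t)^gamma * log(p_t)
    return -alpha * (1 - p_t) ** gamma * math.log(p_t)

easy = focal_term(0.9)   # ~0.00026: the (1 - 0.9)^2 factor nearly zeroes it out
hard = focal_term(0.1)   # ~0.466
print(hard / easy)       # ~1770: hard samples dominate the loss

Under plain cross entropy the same two samples would differ only by a factor of about 22 (-log(0.1) versus -log(0.9)), so the modulating factor shifts training effort strongly toward hard examples.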
import torch
import torch.nn as nn

class FocalLoss(nn.Module):
    def __init__(self, class_num, alpha=None, gamma=1.5, size_average=True):
        super(FocalLoss, self).__init__()
        if alpha is None:
            # default per-class weights: down-weight class 0 (the easy
            # background class), keep weight 1 for the remaining classes
            alpha = torch.ones(class_num)
            alpha[0] = 0.3
        elif not isinstance(alpha, torch.Tensor):
            alpha = torch.tensor(alpha, dtype=torch.float)
        self.alpha = alpha
        self.gamma = gamma
        self.class_num = class_num
        self.size_average = size_average
    def forward(self, inputs, targets):
        # inputs: raw logits of shape (N, C); targets: class indices of shape (N,)
        N = inputs.size(0)
        C = inputs.size(1)
        P = inputs.softmax(dim=1)

        # one-hot mask that selects the probability of the ground-truth class
        ids = targets.view(-1, 1)
        class_mask = torch.zeros(N, C, device=inputs.device).scatter_(1, ids, 1.)

        # per-sample alpha_t, moved to the same device as the inputs
        alpha = self.alpha.to(inputs.device)[targets].view(-1, 1)

        # p_t: predicted probability of the ground-truth class
        probs = (P * class_mask).sum(1).view(-1, 1)
        log_p = probs.clamp(min=1e-12).log()  # clamp to avoid log(0)

        # FL = -alpha_t * (1 - p_t)^gamma * log(p_t)
        batch_loss = -alpha * torch.pow(1 - probs, self.gamma) * log_p

        if self.size_average:
            loss = batch_loss.mean()
        else:
            loss = batch_loss.sum()
        return loss
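A minimal usage sketch (the batch size, class count, and random logits below are hypothetical):

criterion = FocalLoss(class_num=8)
logits = torch.randn(4, 8, requires_grad=True)  # raw network outputs, shape (N, C)
targets = torch.tensor([0, 2, 5, 7])            # ground-truth class indices, shape (N,)
loss = criterion(logits, targets)
loss.backward()
print(loss.item())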