代码来自知乎大神的文章 https://zhuanlan.zhihu.com/p/28527749 ，
复制于此，方便自己学习。
class FocalLoss(nn.Module):
    r"""Focal Loss for multi-class classification.

    Implements the loss proposed in "Focal Loss for Dense Object Detection":

        Loss(x, class) = -\alpha[class] * (1 - softmax(x)[class])^gamma
                         * \log(softmax(x)[class])

    The log-probability is computed with ``log_softmax`` rather than
    ``log(softmax(x))``, so a very confident wrong prediction produces a
    large finite loss instead of ``inf``.

    Args:
        class_num (int): number of classes C. Used to build the default
            all-ones ``alpha`` and to broadcast a single-value ``alpha``.
        alpha (Tensor | sequence | float | None): per-class weighting
            factor with C entries (any shape, e.g. (C,) or (C, 1)), or a
            single value applied to every class. Defaults to 1 for every
            class.
        gamma (float): focusing parameter, gamma >= 0. Larger values reduce
            the relative loss for well-classified examples (p > .5),
            putting more focus on hard, misclassified examples.
            gamma = 0 gives alpha-weighted cross entropy.
        size_average (bool): if True (default), the per-sample losses are
            averaged over the minibatch; if False, they are summed.
    """

    def __init__(self, class_num=9, alpha=None, gamma=2, size_average=True):
        super(FocalLoss, self).__init__()
        if alpha is None:
            alpha = torch.ones(class_num, 1)
        else:
            # Accept tensors, lists, numpy arrays and plain numbers alike.
            alpha = torch.as_tensor(alpha)
            if alpha.numel() == 1:
                # A single scalar weights every class equally; without this a
                # 0-dim tensor cannot be indexed by class id.
                alpha = alpha.reshape(1, 1).expand(class_num, 1)
        self.alpha = alpha
        self.gamma = gamma
        self.class_num = class_num
        self.size_average = size_average

    def forward(self, inputs, targets):
        """Compute the focal loss for a minibatch.

        Args:
            inputs: raw (unnormalised) scores of shape (N, C).
            targets: int64 class indices of shape (N,).

        Returns:
            A scalar tensor: the mean (``size_average=True``) or the sum of
            the per-sample focal losses.
        """
        ids = targets.reshape(-1, 1)  # (N, 1), as required by gather

        # log p_t of the target class, computed stably from the logits.
        log_p = F.log_softmax(inputs, dim=1).gather(1, ids)  # (N, 1)
        probs = log_p.exp()  # p_t, (N, 1)

        # Move alpha to the inputs' device/dtype on every call instead of
        # caching a CUDA copy on self. Flattening first makes (C,), (C, 1)
        # and (1, C) alphas all index to (N,), and reshaping to (N, 1)
        # prevents an (N, N) broadcast against the (N, 1) terms below.
        alpha = self.alpha.to(device=inputs.device, dtype=inputs.dtype)
        alpha = alpha.reshape(-1)[ids.reshape(-1)].reshape(-1, 1)

        batch_loss = -alpha * torch.pow(1 - probs, self.gamma) * log_p

        if self.size_average:
            return batch_loss.mean()
        return batch_loss.sum()
原知乎作者在评论中表示，他在实验中所有类别的 alpha 都取了 1。若想为不同类别赋予不同的 alpha 值，可参考 镜中隐 https://blog.csdn.net/qq_36401512/article/details/91491205 的修改实现，该实现考虑了不同类别的出现频率。