focal loss的pytorch实现代码理解

 

代码来自知乎大神    https://zhuanlan.zhihu.com/p/28527749

copy方便自己的学习 

class FocalLoss(nn.Module):
    r"""
        This criterion is a implemenation of Focal Loss, which is proposed in
        Focal Loss for Dense Object Detection.

            Loss(x, class) = - \alpha (1-softmax(x)[class])^gamma \log(softmax(x)[class])

        The losses are averaged across observations for each minibatch.
        Args:
            alpha(1D Tensor, Variable) : the scalar factor for this criterion
            gamma(float, double) : gamma > 0; reduces the relative loss for well-classified examples (p > .5),
                                   putting more focus on hard, misclassified examples
            size_average(bool): size_average(bool): By default, the losses are averaged over observations for each minibatch.
                                However, if the field size_average is set to False, the losses are
                                instead summed for each minibatch.
    """

    def __init__(self, class_num=9, alpha=None, gamma=2, size_average=True):
        super(FocalLoss, self).__init__()
        if alpha is None:
            self.alpha = Variable(torch.ones(class_num, 1))
        else:
            if isinstance(alpha, Variable):
                self.alpha = alpha
            else:
                self.alpha = Variable(alpha)
        self.gamma = gamma
        self.class_num = class_num
        self.size_average = size_average

    def forward(self, inputs, targets):   # input shape (N,C); target shape (N, )
        N = inputs.size(0)    #   batch大小
        C = inputs.size(1)    #   类别数
        P = F.softmax(inputs)

        class_mask = inputs.data.new(N, C).fill_(0)    # class_mask shape (N,C) 全0填充
        class_mask = Variable(class_mask)
        ids = targets.view(-1, 1)    # ids shape (N,1)
        class_mask.scatter_(1, ids.data, 1.)  #  scatter_函数将src中数据根据index中的索引按照dim=1(行)的方向填进class_mask中, one-hot表示

        if inputs.is_cuda and not self.alpha.is_cuda:
            self.alpha = self.alpha.cuda()
        alpha = self.alpha[ids.data.view(-1)]   #   alpha shape (N, 1) 全1


        probs = (P * class_mask).sum(1).view(-1, 1)    
        log_p = probs.log() 
   
        # 先softmax, 再log, 标准交叉熵

        # print('probs size= {}'.format(probs.size()))
        # print(probs)

        batch_loss = -alpha * (torch.pow((1 - probs), self.gamma)) * log_p

        # print('-----bacth_loss------')
        # print(batch_loss)

        if self.size_average:
            loss = batch_loss.mean()
        else:
            loss = batch_loss.sum()
        return loss

 原知乎作者在评论中说他在试验中多有类别的alpha都取了1,  若想对不同类别赋予不同alpha值尝试, 可参考 镜中隐    https://blog.csdn.net/qq_36401512/article/details/91491205  的修改实现, 考虑不同类别的频率.

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值