Focal Loss源码

class FocalLoss(nn.Module):
    def __init__(self, alpha=0.25, gamma=2, size_average=True):
        super(FocalLoss, self).__init__()
        self.alpha = torch.tensor(alpha).cuda()
        self.gamma = gamma
        self.size_average = size_average

    def forward(self, pred, target):
        # 如果模型最后没有 nn.Sigmoid(),那么这里就需要对预测结果计算一次 Sigmoid 操作
        # pred = nn.Sigmoid()(pred)

        # 展开 pred 和 target,此时 pred.size = target.size = (BatchSize,1)
        pred = pred.view(-1,1)
        target = target.view(-1,1)

        # 此处将预测样本为正负的概率都计算出来,此时 pred.size = (BatchSize,2)
        pred = torch.cat((1-pred,pred),dim=1)

        # 根据 target 生成 mask,即根据 ground truth 选择所需概率
        # 用大白话讲就是:
        # 当标签为 1 时,我们就将模型预测该样本为正类的概率代入公式中进行计算
        # 当标签为 0 时,我们就将模型预测该样本为负类的概率代入公式中进行计算
        class_mask = torch.zeros(pred.shape[0],pred.shape[1]).cuda()
        # 这里的 scatter_ 操作不常用,其函数原型为:
        # scatter_(dim,index,src)->Tensor
        # Writes all values from the tensor src into self at the indices specified in the index tensor.
        # For each value in src, its output index is specified by its index in src for dimension != dim and by the corresponding value in index for dimension = dim.
        class_mask.scatter_(1, target.view(-1, 1).long(), 1.)#one_hot编码

        # 利用 mask 将所需概率值挑选出来
        probs = (pred * class_mask).sum(dim=1).view(-1,1)
        probs = probs.clamp(min=0.0001,max=1.0)

        # 计算概率的 log 值
        log_p = probs.log()

        # 根据论文中所述,对 alpha 进行设置(该参数用于调整正负样本数量不均衡带来的问题)
        alpha = torch.ones(pred.shape[0],pred.shape[1]).cuda()
        alpha[:,0] = alpha[:,0] * (1-self.alpha)
        alpha[:,1] = alpha[:,1] * self.alpha
        alpha = (alpha * class_mask).sum(dim=1).view(-1,1)

        # 根据 Focal Loss 的公式计算 Loss
        batch_loss = -alpha*(torch.pow((1-probs), self.gamma))*log_p

         # Loss Function的常规操作,mean 与 sum 的区别不大,相当于学习率设置不一样而已
        if self.size_average:
            loss = batch_loss.mean()
        else:
            loss = batch_loss.sum()

        return loss

同样咱们以下面这个target和pred去debug

if __name__ == '__main__':
    target = torch.tensor([0, 1])
    pred = torch.tensor([0.98, 0.07])
    target = Variable(target).cuda()
    pred = Variable(pred).cuda()
    focalloss = FocalLoss()
    loss = focalloss(pred, target)
    print(loss)

target:0代表背景,1代表前景

pred:预测的概率都是模型认为是前景的概率

先reshape,然后将1-pred与pred进行concat操作得到pred,(1-pred代表模型认为是背景的概率,pred则代表的是前景的概率。)class_mask会将target[0,1]转换成one-hot形式,class_mask = [[1,0],[0,1]].probs则是将one-hot编码中对应的概率值选出得到probs[[0.02],[0.07]],同时限制其取值最大1.0,最小0.001.这里的probs也就是原公式中的pt。然后取对数得到就是公式中的log(pt)。然后引入α,如果本来是背景类,那么前面的系数应该是(1-α),前景类则是α。拿到这些参数后则可以进行loss计算了,

 

 

 

 

  • 1
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值