详解Focal Loss以及PyTorch代码

原理

从17年被RetinaNet提出,Focal Loss 一直备受好评。由于其着重关注分类较差的样本的思想,Focal loss以简单的形式,一定程度解决了样本的难例挖掘,样本不均衡的问题。

普通的Cross Entropy

C E ( p t ) = − a t l o g ( p t ) CE(p_t) = -a_t log(p_t) CE(pt)=atlog(pt)
a t a_t at是平衡因子。

Focal Loss

F L ( p t ) = − ( 1 − p t ) r l o g ( p t ) FL(p_t) = -(1-p_t)^rlog(p_t) FL(pt)=(1pt)rlog(pt)
在log前面加上 ( 1 − p t ) (1-p_t) (1pt)是focal loss的核心。假设 r r r设置为2。当 p t p_t pt为0.9,说明网络已经将这个样本分的很好了,那么 ( 1 − p t ) 2 (1-p_t)^2 (1pt)2 为0.01,呈指数级降低了这个样本对损失函数的贡献。当 p t p_t pt为0.1,说明网络对样本还不具有很好地分类能力,那么 ( 1 − p t ) 2 (1-p_t)^2 (1pt)2为0.81。 简单言之,focal加大了对难分类样本的关注。

代码

来自知乎大佬

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

class FocalLoss(nn.Module):
    def __init__(self, class_num, alpha=None, gamma=2, size_average=True):
        super(FocalLoss, self).__init__()
        if alpha is None:  # alpha 是平衡因子
            self.alpha = Variable(torch.ones(class_num, 1))
        else:
            if isinstance(alpha, Variable):
                self.alpha = alpha
            else:
                self.alpha = Variable(alpha)
        self.gamma = gamma  # 指数
        self.class_num = class_num  # 类别数目
        self.size_average = size_average  # 返回的loss是否需要mean一下

    def forward(self, inputs, targets):
        # target : N, 1, H, W
        inputs = inputs.permute(0, 2, 3, 1)
        targets = targets.permute(0, 2, 3, 1)
        num, h, w, C = inputs.size()
        N = num * h * w
        inputs = inputs.reshape(N, -1)   # N, C
        targets = targets.reshape(N, -1)  # 待转换为one hot label
        P = F.softmax(inputs, dim=1)  # 先求p_t
        class_mask = inputs.data.new(N, C).fill_(0)
        class_mask = Variable(class_mask)
        ids = targets.view(-1, 1)
        class_mask.scatter_(1, ids.data, 1.)  # 得到label的one_hot编码

        if inputs.is_cuda and not self.alpha.is_cuda:
            self.alpha = self.alpha.cuda()  # 如果是多GPU训练 这里的cuda要指定搬运到指定GPU上 分布式多进程训练除外
        alpha = self.alpha[ids.data.view(-1)]
        # y*p_t  如果这里不用*, 还可以用gather提取出正确分到的类别概率。
        # 之所以能用sum,是因为class_mask已经把预测错误的概率清零了。
        probs = (P * class_mask).sum(1).view(-1, 1)
        # y*log(p_t)
        log_p = probs.log()
        # -a * (1-p_t)^2 * log(p_t)
        batch_loss = -alpha * (torch.pow((1 - probs), self.gamma)) * log_p

        if self.size_average:
            loss = batch_loss.mean()
        else:
            loss = batch_loss.sum()
        return loss

在代码我写了清晰的注释。该Focal loss可适用于大于2类的分类任务。

  • 6
    点赞
  • 35
    收藏
    觉得还不错? 一键收藏
  • 4
    评论
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值