【机器学习】focalloss原理以及pytorch实现

最近在做一个分类项目,发现很多“难样本”比较不好处理又特别重要,想试试FocalLoss。没找到pytorch相关实现,本来想研究pytorch的cross_entropy源码,稍微改一下(怕手残自己写的loss效率比较低),但是发现有点复杂,我的任务比较简单,改那玩意有点累。
我们知道,对于二分类:
c r o s s _ e n t r o p y ( y , y ^ ) = − y l o g y ^ − ( 1 − y ) l o g 1 − y ^ cross\_entropy(y,\hat y) = -ylog^{\hat y}-(1-y)log^{1-\hat y} cross_entropy(y,y^)=ylogy^1ylog1y^

c r o s s _ e n t r o p y ( y , y ^ ) = { − l o g y ^ y=1 − l o g 1 − y ^ y=0 cross\_entropy(y,\hat y) =\begin{cases} -log^{\hat y}& \text{y=1}\\ -log^{1-\hat y}& \text{y=0} \end{cases} cross_entropy(y,y^)={logy^log1y^y=1y=0
y ^ \hat y y^为模型预测概率

如果有一个正样本,模型预测结果为0.9,loss为-log(0.9)约等于0.046

还有一个正样本,模型预测结果为0.55,loss为-log(0.55)约等于0.260

这个预测为0.55的样本提供的loss是预测为0.9的样本的5.65

如果我把公式改成下面这样:
γ = 2 F o c a l L o s s ( y , y ^ ) = { − ( 1 − y ^ ) γ l o g y ^ y=1 − y ^ γ l o g 1 − y ^ y=0 \gamma=2\\ FocalLoss(y,\hat y) = \begin{cases} -(1-\hat y)^{\gamma}log^{\hat y}& \text{y=1}\\ -\hat y^{\gamma}log^{1-\hat y}& \text{y=0}\\ \end{cases} γ=2FocalLoss(y,y^)={(1y^)γlogy^y^γlog1y^y=1y=0
这时如果有一个正样本,模型预测结果为0.9,loss为-0.1*0.1*log(0.9)约等于0.00046

还有一个正样本,模型预测结果为0.55,loss为-0.45*0.45*log(0.55)约等于0.0526

这个预测为0.55的样本提供的loss是预测为0.9的样本的114.35

这样就可以让模型更加更加关注“难样本”

另外还可以给正负样本的loss添加权重,让模型更注重正/负样本

上代码:

import torch
import torch.nn as nn
import torch.nn.functional as F

class FocalLoss(nn.Module):
    def __init__(self, device, gamma, alpha):
        super(FocalLoss, self).__init__()
        #self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        self.device = device
        self.gamma = gamma
        self.alpha = alpha
        
    def forward(self, inputs, targets): 
        if self.device == 'cpu':
            # 计算正负样本权重
            alpha_factor = torch.ones(targets.shape) * self.alpha
            alpha_factor = torch.where(torch.eq(targets, 1), alpha_factor, 1. - alpha_factor)
            # 计算因子项
            focal_weight = torch.where(torch.eq(targets, 1), 1. - inputs, inputs)
            # 得到最终的权重
            focal_weight = alpha_factor * torch.pow(focal_weight, self.gamma)
            targets = targets.type(torch.FloatTensor) 
            # 计算标准交叉熵
            bce = F.binary_cross_entropy(inputs, targets)
            # focal loss 
            cls_loss = focal_weight * bce
        else:
            gpu_targets = targets.cuda()
            gpu_inputs = inputs.cuda()
            alpha_factor = torch.ones(gpu_targets.shape).cuda() * self.alpha
            alpha_factor = torch.where(torch.eq(gpu_targets, 1), alpha_factor, 1. - alpha_factor)
            focal_weight = torch.where(torch.eq(gpu_targets, 1), 1. - gpu_inputs, gpu_inputs)
            focal_weight = alpha_factor * torch.pow(focal_weight, self.gamma)
            targets = targets.type(torch.FloatTensor)
            bce = F.binary_cross_entropy(gpu_inputs, gpu_targets)
            focal_weight = focal_weight.cuda()
            cls_loss = focal_weight * bce

        return cls_loss.sum()

优化了一下,方便最后一层使用softmax,并且减少一些计算量

class FocalLoss(nn.Module):
    def __init__(self, device, gamma, alpha):
        super(FocalLoss, self).__init__()
        self.device = device
        self.gamma = gamma
        self.w = torch.tensor([1 - alpha, alpha],dtype=torch.float32,device=self.device)
        self.log_softmax = nn.LogSoftmax(dim=1)
        self.nllloss = nn.NLLLoss(weight=self.w,reduction='mean')
        
    def forward(self, inputs, targets):
        inputs.to(self.device)
        targets.to(self.device)
        # 计算softmax,稳定算法
        inputs_log_softmax = self.log_softmax(inputs)
        inputs_softmax = torch.exp(inputs_log_softmax)
        # 计算幂数因子项
        focal_weight = torch.where(torch.eq(targets, 1), 1. - inputs_softmax[:,1,:,:], inputs_softmax[:,1,:,:])
        # 得到最终的权重
        focal_weight = torch.pow(focal_weight, self.gamma)
        # focal loss 
        cls_loss = focal_weight * self.nllloss(inputs_log_softmax,targets)
        return cls_loss.sum()

希望能帮助到大家~

pytorch学习笔记 | Focal loss的原理与pytorch实现

  • 2
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 8
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 8
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值