Focal loss 基于pytorch实现.

Focal_loss

作者是一名深度学习工程师,主要研究计算机视觉与三维点云处理,作者Github:Github.欢迎多多交流.

同时作者也是pytorch中文网的维护者之一,网站会不定时更新深度学习相关文章,欢迎大家多多支持.


retinanet是ICCV2017的Best Student Paper Award(最佳学生论文),何凯明是其作者之一.文章中最为精华的部分就是损失函数 Focal loss的提出.

论文中提出类别失衡是造成two-stage与one-stage模型精确度差异的原因.并提出了Focal loss损失函数,通过调整类间平衡因子与难易度平衡因子.最终使one-stage模型达到了two-stage的精确度.


本项目基于pytorch实现focal loss,力图给你原生pytorch损失函数的使用体验.

一. 项目简介

实现过程简易明了,全中文备注.

阿尔法α 参数用于调整类别权重

伽马γ 参数用于调整不同检测难易样本的权重,让模型快速关注于困难样本

完整项目地址:Github,欢迎star, fork. github还有其他视觉相关项目

github连接较慢的,可以去Gitee(国内的代码托管网站),也有完整项目.

项目配有 Jupyter-Notebook 作为focal loss使用例子.

二. 损失函数公式

focal loss 损失函数基于交叉熵损失函数,在交叉熵的基础上,引入了α与γ两个不同的调整因子.

2.1 交叉熵损失

2.2 带平衡因子的交叉熵

2.3 Focal loss损失

加入 (1-pt)γ 平衡难易样本的权重,通过γ缩放因子调整,retainnet默认γ=2

2.4 带平衡因子的Focal损失

论文中最终为带平衡因子的focal loss, 本项目实现的也是这个版本


三. Focal loss实现

# -*- coding: utf-8 -*-
# @Author  : LG
from torch import nn
import torch
from torch.nn import functional as F

class focal_loss(nn.Module):    
    def __init__(self, alpha=0.25, gamma=2, num_classes = 3, size_average=True):
        """
        focal_loss损失函数, -α(1-yi)**γ *ce_loss(xi,yi)      
        步骤详细的实现了 focal_loss损失函数.
        :param alpha:   阿尔法α,类别权重.      当α是列表时,为各类别权重,当α为常数时,类别权重为[α, 1-α, 1-α, ....],常用于 目标检测算法中抑制背景类 , retainnet中设置为0.25
        :param gamma:   伽马γ,难易样本调节参数. retainnet中设置为2
        :param num_classes:     类别数量
        :param size_average:    损失计算方式,默认取均值
        """
        
        super(focal_loss,self).__init__()
        self.size_average = size_average
        if isinstance(alpha,list):
            assert len(alpha)==num_classes   # α可以以list方式输入,size:[num_classes] 用于对不同类别精细地赋予权重
            print("Focal_loss alpha = {}, 将对每一类权重进行精细化赋值".format(alpha))
            self.alpha = torch.Tensor(alpha)
        else:
            assert alpha<1   #如果α为一个常数,则降低第一类的影响,在目标检测中为第一类
            print(" --- Focal_loss alpha = {} ,将对背景类进行衰减,请在目标检测任务中使用 --- ".format(alpha))
            self.alpha = torch.zeros(num_classes)
            self.alpha[0] += alpha
            self.alpha[1:] += (1-alpha) # α 最终为 [ α, 1-α, 1-α, 1-α, 1-α, ...] size:[num_classes]
        self.gamma = gamma
        
    def forward(self, preds, labels):
        """
        focal_loss损失计算        
        :param preds:   预测类别. size:[B,N,C] or [B,C]    分别对应与检测与分类任务, B 批次, N检测框数, C类别数        
        :param labels:  实际类别. size:[B,N] or [B]        
        :return:
        """        
        # assert preds.dim()==2 and labels.dim()==1        
        preds = preds.view(-1,preds.size(-1))        
        self.alpha = self.alpha.to(preds.device)        
        preds_softmax = F.softmax(preds, dim=1) # 这里并没有直接使用log_softmax, 因为后面会用到softmax的结果(当然你也可以使用log_softmax,然后进行exp操作)        
        preds_logsoft = torch.log(preds_softmax)
        preds_softmax = preds_softmax.gather(1,labels.view(-1,1))   # 这部分实现nll_loss ( crossempty = log_softmax + nll )        
        preds_logsoft = preds_logsoft.gather(1,labels.view(-1,1))        
        self.alpha = self.alpha.gather(0,labels.view(-1))        
        loss = -torch.mul(torch.pow((1-preds_softmax), self.gamma), preds_logsoft)  # torch.pow((1-preds_softmax), self.gamma) 为focal loss中 (1-pt)**γ
        loss = torch.mul(self.alpha, loss.t())        
        if self.size_average:        
            loss = loss.mean()        
        else:            
            loss = loss.sum()        
        return loss

详细的使用例子请到Github查看jupyter-notebook.

说明

完整项目地址:Github,欢迎star, fork.

仅限用于交流学习,如需引用,请联系作者.

  • 5
    点赞
  • 32
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值