tensorflow分类的loss函数_图解Focal Loss以及Tensorflow实现(二分类、多分类)

总体上讲,Focal Loss是一个缓解分类问题中类别不平衡、难易样本不均衡的损失函数。首先看一下论文中的这张图:解释:横轴是ground truth类别对应的概率(经过sigmoid/softmax处理过的logits),纵轴是对应的loss值;蓝色的线(gamma=0),就是原始交叉熵损失函数,可以明显看出ground truth的概率越大,loss越小,符合常识;除了蓝色的线,其他几个都是F...
摘要由CSDN通过智能技术生成

总体上讲,Focal Loss是一个缓解分类问题中类别不平衡、难易样本不均衡的损失函数。首先看一下论文中的这张图:

解释:

横轴是ground truth类别对应的概率(经过sigmoid/softmax处理过的logits),纵轴是对应的loss值;

蓝色的线(gamma=0),就是原始交叉熵损失函数,可以明显看出ground truth的概率越大,loss越小,符合常识;

除了蓝色的线,其他几个都是Focal Loss的线,其实原始交叉熵损失函数是Focal Loss的特殊版本(gamma=0)

其他几个Focal Loss线都在蓝色下边,可以看出Focal Loss的作用就是【衰减】;

从图中可以看出,ground truth的概率越大(即容易分类的简单样本),衰减越厉害,也就是大大降低了简单样本的loss;

从图中可以看出,ground truth的概率越小(即不易分类的困难样本),也是有衰减的,但是衰减的程度比较小;

下边是我自己模拟的一组数据,一组固定的logits=[0+epsilon, 0.1, 0.2, ..., 0.9, 1.0-epsilon],然后假设ground truth分别是0、1、2、...、9、10的时候,gamma=0、0.5、1、2、...、8、16对应的loss。

例如第3行第1列的2.75表示,ground truth是类别2,即对应的logits是0.2,gamma=0的时候,loss=2.75(gamma=0,就是原始的多分类交叉熵)。

根据上表可以得到下边的图:

从上图可以看出,随着gamma增大,整体loss都下降了,但是logits相对越高(这个例子中最大logits=1),下降的倍数越大。从上表的最后一列也可以看出来,gamma=0和gamma=16的时候,logits=0只衰减了2倍,但是logits=1衰减了16倍

  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
下面是 circle loss 和 focal loss 的简单实现(包括二分类多分类): Circle Loss: ```python import torch import torch.nn as nn import torch.nn.functional as F class CircleLoss(nn.Module): def __init__(self, m=0.25, s=30): super(CircleLoss, self).__init__() self.m = m self.s = s def forward(self, feats, labels): sim_mat = torch.matmul(feats, feats.t()) mask = labels.expand(labels.size(0), labels.size(0)).t().eq(labels.expand(labels.size(0), labels.size(0))) pos_mask = mask.triu(diagonal=1) neg_mask = mask.logical_not().triu(diagonal=1) pos_sim = sim_mat[pos_mask] neg_sim = sim_mat[neg_mask] alpha_p = F.relu(-pos_sim.detach() + 1 + self.m) alpha_n = F.relu(neg_sim.detach() + self.m) delta_p = 1 - self.m delta_n = self.m logit_p = -self.s * alpha_p * (pos_sim - delta_p) logit_n = self.s * alpha_n * (neg_sim - delta_n) logit = torch.cat([logit_p, logit_n], dim=0) loss = F.softplus(torch.logsumexp(logit, dim=0)) return loss ``` Focal Loss: ```python import torch import torch.nn as nn import torch.nn.functional as F class FocalLoss(nn.Module): def __init__(self, alpha=0.25, gamma=2, reduction='mean'): super(FocalLoss, self).__init__() self.alpha = alpha self.gamma = gamma self.reduction = reduction def forward(self, inputs, targets): BCE_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none') pt = torch.exp(-BCE_loss) F_loss = self.alpha * (1-pt)**self.gamma * BCE_loss if self.reduction == 'mean': return torch.mean(F_loss) elif self.reduction == 'sum': return torch.sum(F_loss) else: return F_loss ``` 以上代码适用于 PyTorch 深度学习框架。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值