I've been working on a classification project recently and found that the "hard examples" are both tricky to handle and especially important, so I wanted to try Focal Loss. I couldn't find an existing PyTorch implementation. My first idea was to dig into PyTorch's cross_entropy source and tweak it slightly (I was afraid a hand-written loss would be inefficient), but it turned out to be fairly involved, and since my task is simple, modifying it felt like too much work.
We know that for binary classification:
$$cross\_entropy(y,\hat y) = -y\log \hat y-(1-y)\log(1-\hat y)$$
That is:
$$cross\_entropy(y,\hat y) = \begin{cases} -\log \hat y & y=1 \\ -\log(1-\hat y) & y=0 \end{cases}$$
where $\hat y$ is the probability predicted by the model.
Say we have a positive sample the model predicts at 0.9: its loss is -log(0.9) ≈ 0.046.
Another positive sample predicted at 0.55 has loss -log(0.55) ≈ 0.260.
So the sample predicted at 0.55 contributes about 5.65 times as much loss as the one predicted at 0.9.
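As a quick check of these numbers (a throwaway sketch; the values above come from base-10 logarithms, though the ratio is the same in any base):

```python
import math

# Cross-entropy for two positive samples (base-10 logs to match the numbers above)
loss_easy = -math.log10(0.9)    # ~0.046
loss_hard = -math.log10(0.55)   # ~0.260
print(loss_hard / loss_easy)    # ~5.67
```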
Now suppose I change the formula to the following:
$$\gamma=2$$
$$FocalLoss(y,\hat y) = \begin{cases} -(1-\hat y)^{\gamma}\log \hat y & y=1 \\ -\hat y^{\gamma}\log(1-\hat y) & y=0 \end{cases}$$
Now the positive sample predicted at 0.9 has loss -0.1 × 0.1 × log(0.9) ≈ 0.00046,
and the positive sample predicted at 0.55 has loss -0.45 × 0.45 × log(0.55) ≈ 0.0526.
The sample predicted at 0.55 now contributes roughly 114 times as much loss as the one predicted at 0.9.
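And the same check with the focal weighting applied (again just a sketch, using base-10 logs):

```python
import math

gamma = 2
# Focal loss for the same two positive samples: -(1 - p)^gamma * log(p)
fl_easy = -((1 - 0.9) ** gamma) * math.log10(0.9)    # ~0.00046
fl_hard = -((1 - 0.55) ** gamma) * math.log10(0.55)  # ~0.0526
print(fl_hard / fl_easy)                             # ~115
```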
This makes the model pay much more attention to the "hard examples".
In addition, you can weight the losses of positive and negative samples so the model focuses more on the positive (or negative) class, as written out below.
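Written out, that α-weighted form looks like this (matching the convention in the code below, where the positive class gets weight $\alpha$ and the negative class gets $1-\alpha$):

$$FocalLoss(y,\hat y) = \begin{cases} -\alpha(1-\hat y)^{\gamma}\log \hat y & y=1 \\ -(1-\alpha)\hat y^{\gamma}\log(1-\hat y) & y=0 \end{cases}$$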
Here's the code:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class FocalLoss(nn.Module):
    def __init__(self, device, gamma, alpha):
        super(FocalLoss, self).__init__()
        self.device = device
        self.gamma = gamma
        self.alpha = alpha

    def forward(self, inputs, targets):
        # inputs: predicted probabilities in [0, 1]; targets: 0/1 labels
        inputs = inputs.to(self.device)
        targets = targets.to(self.device).float()
        # Positive/negative sample weights: alpha for positives, 1 - alpha for negatives
        alpha_factor = torch.ones(targets.shape, device=self.device) * self.alpha
        alpha_factor = torch.where(torch.eq(targets, 1), alpha_factor, 1. - alpha_factor)
        # Modulating factor: 1 - p for positives, p for negatives
        focal_weight = torch.where(torch.eq(targets, 1), 1. - inputs, inputs)
        # Final per-sample weight
        focal_weight = alpha_factor * torch.pow(focal_weight, self.gamma)
        # Standard per-sample binary cross-entropy; reduction='none' keeps one
        # value per sample so the focal weight can be applied element-wise
        bce = F.binary_cross_entropy(inputs, targets, reduction='none')
        # focal loss
        cls_loss = focal_weight * bce
        return cls_loss.sum()
```
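A minimal usage sketch for this version, assuming `inputs` are probabilities that have already gone through a sigmoid (shapes and values here are made up for illustration):

```python
# Hypothetical usage: sigmoid probabilities and 0/1 labels
device = 'cuda' if torch.cuda.is_available() else 'cpu'
criterion = FocalLoss(device=device, gamma=2, alpha=0.25)

inputs = torch.sigmoid(torch.randn(8))   # predicted probabilities, shape (8,)
targets = torch.randint(0, 2, (8,))      # 0/1 labels, shape (8,)
loss = criterion(inputs, targets)
print(loss.item())
```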
I then tweaked it to work with a softmax final layer and to cut down some computation:
```python
class FocalLoss(nn.Module):
    def __init__(self, device, gamma, alpha):
        super(FocalLoss, self).__init__()
        self.device = device
        self.gamma = gamma
        # Class weights: 1 - alpha for class 0, alpha for class 1
        self.w = torch.tensor([1 - alpha, alpha], dtype=torch.float32, device=self.device)
        self.log_softmax = nn.LogSoftmax(dim=1)
        # reduction='none' keeps the per-pixel loss so the focal weight
        # can be applied element-wise
        self.nllloss = nn.NLLLoss(weight=self.w, reduction='none')

    def forward(self, inputs, targets):
        inputs = inputs.to(self.device)
        targets = targets.to(self.device)
        # Numerically stable softmax: take log-softmax first, then exponentiate
        inputs_log_softmax = self.log_softmax(inputs)
        inputs_softmax = torch.exp(inputs_log_softmax)
        # Modulating factor: 1 - p1 for positive pixels, p1 for negative pixels
        focal_weight = torch.where(torch.eq(targets, 1),
                                   1. - inputs_softmax[:, 1, :, :],
                                   inputs_softmax[:, 1, :, :])
        # Final weight
        focal_weight = torch.pow(focal_weight, self.gamma)
        # focal loss
        cls_loss = focal_weight * self.nllloss(inputs_log_softmax, targets)
        return cls_loss.sum()
```
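And a usage sketch for the softmax version; judging from the `[:, 1, :, :]` indexing it expects two-channel logits of shape (N, 2, H, W) and integer masks of shape (N, H, W), which is my assumption here:

```python
# Hypothetical usage: binary segmentation with two-channel logits
device = 'cuda' if torch.cuda.is_available() else 'cpu'
criterion = FocalLoss(device=device, gamma=2, alpha=0.25)

logits = torch.randn(4, 2, 64, 64)        # raw network outputs, (N, 2, H, W)
masks = torch.randint(0, 2, (4, 64, 64))  # per-pixel 0/1 labels, (N, H, W)
loss = criterion(logits, masks)
print(loss.item())
```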
Hope this helps~