日期:2023/12/16
论文:Focal loss for dense object detection
链接:Focal loss for dense object detection
会议:ICCV2017
参考:
[1] focal loss 通俗讲解
[2] CELoss和NLLLoss
[3] softmax 和 log_softmax
因为近期做长尾分类,所以不看目标检测部分,大部分也是参考内容中知乎里的部分
1 Binary Cross-entropy Loss
L
(
y
,
p
^
)
=
−
y
l
o
g
(
p
^
)
−
(
1
−
y
)
l
o
g
(
1
−
p
^
)
L(y,\hat{p}) = -ylog(\hat p)-(1-y)log(1-\hat p)
L(y,p^)=−ylog(p^)−(1−y)log(1−p^)
其中y为二分类任务中的0和1,表示background、foreground;
p
^
\hat p
p^为预测值与GT(既y)的接近程度,越大说明越接近GT。对于一个样本,上式也可以写为:
L
(
y
,
p
^
)
=
{
−
l
o
g
(
p
^
)
,
y
=
1
−
l
o
g
(
1
−
p
^
)
,
y
=
0
L(y,\hat p) = \left\{\begin{matrix} -log(\hat p) ,&& \space y=1 \\ -log(1-\hat p),&& \space y=0 \end{matrix}\right.
L(y,p^)={−log(p^),−log(1−p^), y=1 y=0
那么对于整个训练集,总损失则为下式。其中m和n分别表示正负样本个数,如果
m
≪
n
m\ll n
m≪n,在计算损失的时候负样本就会占主导,即使负样本的损失很小,但是数量很多的话,损失还是会向负样本倾斜,那么模型训练自然也会向负样本倾斜。
L
C
E
=
1
N
(
∑
y
i
=
1
m
l
o
g
(
p
^
)
+
∑
y
i
=
0
n
l
o
g
(
1
−
p
^
)
)
L_{CE} = \frac{1}{N}(\sum_{y_i=1}^mlog(\hat p)+\sum_{y_i=0}^nlog(1-\hat p))
LCE=N1(yi=1∑mlog(p^)+yi=0∑nlog(1−p^))
以下为原文中的一个comment:
2 Balanced CE Loss
解决这一问题的直觉方法就是给正负样本损失加个权重,既
L
C
E
=
1
N
[
∑
y
i
=
1
m
α
l
o
g
(
p
^
)
+
∑
y
i
=
0
n
(
1
−
α
)
l
o
g
(
1
−
p
^
)
]
L_{CE} = \frac{1}{N}[\sum_{y_i=1}^m\alpha log(\hat p)+\sum_{y_i=0}^n(1-\alpha)log(1-\hat p)]
LCE=N1[yi=1∑mαlog(p^)+yi=0∑n(1−α)log(1−p^)]
其中
α
\alpha
α是一个超参数,如果按照正负样本的频率进行取值的话,那就是
α
1
−
α
=
n
m
\frac{\alpha}{1-\alpha}=\frac{n}{m}
1−αα=mn。可以看到, BCE Loss其实是按照正负样本的数量在进行对数似然权重的调整。
3 Focal Loss
BCE Loss通过类频率去改变了对数似然的权重,但它并不区分样本的难易程度。Focal loss也是解决类别不平衡的一种loss,它和BCE loss的角度不一样,Focal loss的思想就是从样本预测的难易程度下手。
其原始形式是:
L
f
l
=
{
−
(
1
−
p
^
)
γ
l
o
g
(
p
^
)
,
y
=
1
−
p
^
γ
l
o
g
(
1
−
p
^
)
,
y
=
0
L_{fl} = \left\{\begin{matrix} -(1-\hat p)^\gamma log(\hat p) ,&& \space y=1 \\ -\hat p^\gamma log(1-\hat p),&& \space y=0 \end{matrix}\right. \\
Lfl={−(1−p^)γlog(p^),−p^γlog(1−p^), y=1 y=0
令
p
t
=
{
p
^
,
y
=
1
1
−
p
^
,
y
=
0
p_t = \left\{\begin{matrix} \hat p,&& y=1\\ 1-\hat p,&& y=0 \end{matrix}\right. \\
pt={p^,1−p^,y=1y=0
则focal loss可以写成统一形式:
L
f
l
=
−
(
1
−
p
t
)
γ
l
o
g
(
p
t
)
L_{fl} = -(1-p_t)^\gamma log(p_t)
Lfl=−(1−pt)γlog(pt)
同理也可以让CE Loss写成统一形式:
L
C
E
=
−
l
o
g
(
p
t
)
L_{CE} = -log(p_t)
LCE=−log(pt)
可以发现,当
γ
=
0
\gamma=0
γ=0,Focal loss就是CE loss,当
γ
≠
0
\gamma \not= 0
γ=0,Focal loss比CE loss多了一个modulating factor
(
1
−
p
t
)
γ
(1-p_t)^\gamma
(1−pt)γ。分析一下可以发现:
1)
p
t
→
1
p_t \to 1
pt→1时, modulating factor
→
0
\to0
→0,
L
f
l
→
0
L_{fl}\to0
Lfl→0。既如果样本很容易,那么对损失的贡献应该很小。
2)
p
t
→
0
p_t \to 0
pt→0时, modulating factor
→
1
\to1
→1,
L
f
l
→
L
C
E
L_{fl}\to L_{CE}
Lfl→LCE。既如果样本很难,对损失贡献很大。
所以说,当预测难度降低的时候,损失倾向于0; 当预测难度增大的时候,损失倾向于维持CE loss,这就增加了预测难度高的(hard example)样本的损失值了。其威力在论文里叙述如下所示,如果
γ
=
2
,
p
t
=
0.9
\gamma=2,p_t=0.9
γ=2,pt=0.9的话,对损失的贡献会下降100倍;
但是在实际中,会使用 α \alpha α平衡变体,因为实验发现相比于non- α \alpha α会提升一些精度,并且使用sigmoid函数去计算 p p p会使得数值产生稳定性
最后的focal loss形式如下👇
L
f
l
=
−
α
t
(
1
−
p
t
)
γ
l
o
g
(
p
t
)
L_{fl} = -\alpha_t(1-p_t)^\gamma log(p_t)
Lfl=−αt(1−pt)γlog(pt)
4 Focal loss 和 CE loss各种探讨
- 这幅图的横坐标指的是预测的概率,由以上分析可知,预测概率高(容易预测)的其损失值应该小,而预测概率低的(hard example)其损失值应该跟CE loss一样。
- 可以发现focal loss相比于CE loss,在well-classified examples上的损失值都要小,并且 γ \gamma γ的作用就是让损失值曲线更陡峭。
- 左边的图主要比较的是BCE loss中 α \alpha α因子的影响,可以发现当 α = 0.75 \alpha=0.75 α=0.75时的效果最好,既正样本给75%的权重,负样本给25%的权重。但是应该跟数据集有关吧。
- 右边的图是Focal loss的两个超参的消融实验,可以发现最佳设置为 ( γ = 2. , α = 0.25 ) (\gamma=2., \alpha=0.25) (γ=2.,α=0.25)
- 知乎上还看到一个有意思的解读,就是 OHEM的AP要比OHEM 1:3的AP要好,这就说明类的个数的不平衡并不是影响实验效果的主要,预测难易程度样本的不平衡才是关键。
- 这是在正负样本上对 γ \gamma γ的一些消融实验
- 左图可以说明,大约20%的正样本占了累积loss的一半, γ \gamma γ变大,也只会使这一数值好一点点,但是影响不大,说明对于hard examples来说,它是更接近与CE Loss的,既modulating factor接近于1。
- 右图可以说明,
γ
\gamma
γ的增大对负样本loss的影响很大,说明让更多容易预测的负样本的损失值变小了。
- 这里 x t > 0 x_t>0 xt>0就说明 p t > 0.5 p_t>0.5 pt>0.5(论文里有假设),由梯度图可以看出,对于FL及FL*(FL的变体)而言,当>0时,他们的梯度都很快趋近于0;这不同于CE Loss,CE loss在>0时还是有很大一部分梯度存在的。
5 代码
以下为3个github库的参考:
代码及注释:
1) focal loss
import torch
from torch import nn
from torch.nn import functional as F
class FocalLoss(nn.Module):
def __init__(self,alpha=None,gamma=2,num_classes=1000,size_average=True):
'''
Focal Loss = -\alpha*(1-y_i)**\gamma*CE_Loss
:param alpha : 类别权重,1)为列表时,为各类别权重; 2)为常数时,类别权重为[\alpha,1-\alpha,1-\alpha...] 抑制背景类
:param gamma : 难易样本调节因子,默认为2
:param num_classes : 类别数
:param size_average : 默认取平均
:attention: 关于alpha是这样设置的:
1) 如果传入的alpha为None,则默认alpha不存在,既为常数1
2) 如果传入的alpha为int或float,则设置为[alpha,1-alpha,1-alpha....],默认第一类为头类(背景)
3) 如果传入的alpha为list,判断是否等于类别大小,然后各个类对应alpha
'''
super().__init__()
self.size_average = size_average
self.alpha = alpha
if isinstance(alpha,list):
assert len(alpha) == num_classes
self.alpha = torch.Tensor(alpha)
elif isinstance(alpha,(int,float)):
self.alpha = torch.zeros(num_classes)
self.alpha[0] = alpha
self.alpha[1:] = 1-alpha
elif alpha is None:
self.alpha = torch.ones(num_classes)
self.gamma = gamma
# print('Focal Loss:\n')
# print('- Alpha: ',self.alpha)
# print('\n- Gamma: ',self.gamma)
def forward(self,preds,labels):
'''
损失计算, B为batch size, C为类别数
:param preds: 分类为[B,C]
:param labels: 分类为[B]
:return: loss
'''
if preds.dim()>2: # 一般来说分类任务是[B,C] 既每张图一个类别
preds = preds.view(preds.size(0),preds.size(1),-1) # N,C,H,W -> N,C,H*W; 假如[128,1000,8,4] ->[128,1000,32]
preds = preds.transpose(1,2) # N,C,H*W -> N,H*W,C ; [128,1000,32] -> [128,32,1000]
preds = preds.contiguous().view(-1,preds.size(2)) # N*H*W,C ; [4096,1000]
preds = preds.view(-1,preds.size(-1)) # [B,C]
preds_logsoft = F.log_softmax(preds,dim=1) # log_softmax:先softmax后再log
preds_softmax = torch.exp(preds_logsoft) # e*(log*softmax) = softmax 概率p
p_t = preds_softmax.gather(1,labels.view(-1,1)) # GT标签对应的概率
preds_logsoft = preds_logsoft.gather(1,labels.view(-1,1)) # CE Loss = softmax + log + nll loss
alpha = self.alpha.to(preds.device)
alpha = alpha.gather(0,labels.view(-1))
loss = -torch.mul(torch.pow((1-p_t),self.gamma),preds_logsoft)
loss = torch.mul(alpha,loss.t())
if self.size_average:
return loss.mean()
else:
return loss.sum()
2) demo
from __future__ import print_function
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.autograd import Variable
import os,sys,random,time
import argparse
from Focal_loss import FocalLoss
start_time = time.time()
maxe = 0
for i in range(1000):
x = torch.rand(12800,2)*random.randint(1,10)
x = Variable(x.cuda())
l = torch.rand(12800).ge(0.1).long()
l = Variable(l.cuda())
output0 = FocalLoss(gamma=0)(x,l)
output1 = nn.CrossEntropyLoss()(x,l)
a = output0.item()
b = output1.item()
if abs(a-b)>maxe: maxe = abs(a-b)
print('time:',time.time()-start_time,'max_error:',maxe)
start_time = time.time()
maxe = 0
for i in range(100):
x = torch.rand(128,1000,8,4)*random.randint(1,10)
x = Variable(x.cuda())
l = torch.rand(128,8,4)*1000 # 1000 is classes_num
l = l.long()
l = Variable(l.cuda())
output0 = FocalLoss(gamma=0)(x,l)
output1 = nn.NLLLoss2d()(F.log_softmax(x),l)
a = output0.item()
b = output1.item()
if abs(a-b)>maxe: maxe = abs(a-b)
print('time:',time.time()-start_time,'max_error:',maxe)