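Preface
Study notes on CrossEntropyLoss and FocalLoss.
Most existing tutorials use binary classification as their example and ignore the multi-class case. This article aims to fill that gap.
In object detection, the multi-class case is further extended to a multi-class foreground/background case. This special case is rarely discussed and confuses many people.
This article briefly reviews the derivations and the binary case, then focuses on the multi-class foreground/background case. It also adds implementation details from several SOTA networks for this case, together with my own understanding and discussion.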
1. Background review
CrossEntropyLoss
CrossEntropyLoss formula:
$$L_{CE} = \frac{1}{N} \sum_{i}^{N} \sum_{j}^{C} -q(i,j) \log p(i,j)$$
where
- $N$ is the number of samples.
- $C$ is the number of classes.
- $q$ has shape $(N, C)$. $q(i,j)$ is the ground-truth probability that sample $i$ belongs to class $j$.
- $p$ has shape $(N, C)$. $p(i,j)$ is the predicted probability that sample $i$ belongs to class $j$. It is usually produced by a softmax, so $p(i,j) \in [0,1]$ and $\sum_j^C p(i,j) = 1$.
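The formula above does not require $q$ to be one-hot. The following minimal sketch checks the formula directly against the library for a soft label $q$. It assumes PyTorch 1.10 or later, where `F.cross_entropy` also accepts a floating-point target of class probabilities with shape $(N, C)$. The tensor values are random and serve only as an illustration.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
N, C = 4, 3
logits = torch.randn(N, C)
p = logits.softmax(dim=1)               # predicted distribution p(i, j), rows sum to 1
q = torch.randn(N, C).softmax(dim=1)    # soft ground-truth distribution q(i, j), rows sum to 1

# L_CE = 1/N * sum_i sum_j -q(i, j) * log p(i, j)
loss_formula = (-q * torch.log(p)).sum(dim=1).mean()

# A float target of shape (N, C) is interpreted as class probabilities (PyTorch >= 1.10)
loss_api = F.cross_entropy(logits, q)

print(loss_formula.item(), loss_api.item())
```

The later sections use only hard labels, which are the special case where each row of $q$ is one-hot.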
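Analysis of the original FocalLoss paper
Note: the notation in this subsection follows the original paper and differs from the notation used in the rest of this article.
The original paper, Focal Loss for Dense Object Detection, introduces Focal Loss with binary classification as its example. It first writes the CrossEntropy Loss as follows. Note that no average over samples is taken here: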
$$CE(p,y)=\begin{cases}-\log(p) & \text{if } y=1\\ -\log(1-p) & \text{otherwise}\end{cases}$$
where $y \in \{\pm 1\}$ is the ground-truth class and $p \in [0,1]$ is the predicted probability for the class $y=1$.
For notational convenience, define $p_t$:
$$p_t=\begin{cases}p & \text{if } y=1\\ 1-p & \text{otherwise}\end{cases}$$
The CrossEntropy Loss then simplifies to:
$$CE(p,y)=CE(p_t)=-\log(p_t)$$
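To balance the classes, a class weight is introduced. $\alpha \in [0,1]$ is the weight for $y=1$ and $1-\alpha$ is the weight for $y=-1$. Define $\alpha_t$ analogously to $p_t$, so that $\alpha_t=\alpha$ if $y=1$ and $\alpha_t=1-\alpha$ otherwise. The weighted CrossEntropy Loss is then: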
$$CE(p_t)=-\alpha_t\log(p_t)$$
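$\alpha_t$ rebalances how much positive and negative samples contribute to the loss. It does not, however, distinguish easy samples from hard ones. Focal Loss is proposed to address this: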
$$FL(p_t) = -\alpha_t (1 - p_t)^{\gamma} \log(p_t)$$
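where $\alpha_t$ is the factor that balances positive and negative samples, and $\gamma \geq 0$ is the factor that balances easy and hard samples. The larger $p_t$ is (the easier the sample), the more $(1-p_t)^\gamma$ shrinks its loss.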
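This plain-Python sketch shows how strongly $(1-p_t)^\gamma$ suppresses easy samples. It computes the ratio $FL/CE = \alpha_t(1-p_t)^\gamma$ with $\alpha_t = 1$, and the $p_t$ values are chosen only for illustration. With $\gamma = 2$, a sample with $p_t = 0.9$ keeps 1% of its CE loss, while a hard sample with $p_t = 0.1$ keeps 81%. The paper makes the same point: at $p_t = 0.9$ the loss is 100× lower than CE.

```python
import math

def ce(pt: float) -> float:
    """Cross entropy written in terms of p_t: CE(p_t) = -log(p_t)."""
    return -math.log(pt)

def focal(pt: float, gamma: float, alpha_t: float = 1.0) -> float:
    """Focal loss written in terms of p_t: FL(p_t) = -alpha_t * (1 - p_t)**gamma * log(p_t)."""
    return alpha_t * (1.0 - pt) ** gamma * ce(pt)

for pt in (0.1, 0.5, 0.9, 0.99):
    ratio = focal(pt, gamma=2.0) / ce(pt)  # equals (1 - p_t)**gamma when alpha_t = 1
    print(f"p_t={pt:<5} CE={ce(pt):.4f}  FL={focal(pt, 2.0):.6f}  FL/CE={ratio:.4f}")
```

With $\gamma = 0$ and $\alpha_t = 1$, `focal` reduces to `ce`, which matches the relationship between the two formulas above.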
For comparison, here is the implementation in Facebook's open-source fvcore library:
```python
import torch
import torch.nn.functional as F


def sigmoid_focal_loss(
    inputs: torch.Tensor,
    targets: torch.Tensor,
    alpha: float = -1,
    gamma: float = 2,
    reduction: str = "none",
) -> torch.Tensor:
    """
    Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.
    Args:
        inputs: A float tensor of arbitrary shape.
                The predictions for each example.
        targets: A float tensor with the same shape as inputs. Stores the binary
                 classification label for each element in inputs
                (0 for the negative class and 1 for the positive class).
        alpha: (optional) Weighting factor in range (0,1) to balance
                positive vs negative examples. Default = -1 (no weighting).
        gamma: Exponent of the modulating factor (1 - p_t) to
               balance easy vs hard examples.
        reduction: 'none' | 'mean' | 'sum'
                 'none': No reduction will be applied to the output.
                 'mean': The output will be averaged.
                 'sum': The output will be summed.
    Returns:
        Loss tensor with the reduction option applied.
    """
    inputs = inputs.float()
    targets = targets.float()
    p = torch.sigmoid(inputs)
    ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")  # = -log(p_t)
    p_t = p * targets + (1 - p) * (1 - targets)
    loss = ce_loss * ((1 - p_t) ** gamma)  # modulating factor (1 - p_t)^gamma

    if alpha >= 0:
        alpha_t = alpha * targets + (1 - alpha) * (1 - targets)  # alpha_t as defined above
        loss = alpha_t * loss

    if reduction == "mean":
        loss = loss.mean()
    elif reduction == "sum":
        loss = loss.sum()

    return loss
```
Generalizing FocalLoss
Building on the paper and the code, we rewrite Focal Loss in the unified notation of this article and generalize it:
$$L_{Focal} = \frac{1}{N} \sum_{i}^{N} \sum_{j}^{C} -\alpha(c)\,[1-q(i,j)\,p(i,j)]^{\gamma(c)}\, q(i,j) \log p(i,j)$$
where
- $N$ is the number of samples.
- $C$ is the number of classes.
- $\alpha$ has shape $(C,)$. $\alpha(c)$ is the positive/negative weighting factor of class $c$, where $c$ is the ground-truth class of sample $i$ (with hard labels, the $j$ for which $q(i,j)=1$).
- $\gamma$ has shape $(C,)$. $\gamma(c)$ is the easy/hard weighting factor of class $c$.
- $q$ has shape $(N, C)$. $q(i,j)$ is the ground-truth probability that sample $i$ belongs to class $j$.
- $p$ has shape $(N, C)$. $p(i,j)$ is the predicted probability that sample $i$ belongs to class $j$. It is usually produced by a softmax, so $p(i,j) \in [0,1]$ and $\sum_j^C p(i,j) = 1$.
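Here is a quick sanity check of the generalized formula. With $\alpha(c)=1$ and $\gamma(c)=0$ for every class, the modulating factor disappears, so $L_{Focal}$ should equal $L_{CE}$. The sketch implements the formula literally for hard labels and takes $c$ to be the ground-truth class of sample $i$. The helper name `generalized_focal_loss` is hypothetical, and the data is random.

```python
import torch
import torch.nn.functional as F

def generalized_focal_loss(logits, target, alpha, gamma):
    """Hypothetical helper: L_Focal from the formula above, hard labels only.

    logits: (N, C) raw scores; target: (N,) int64 class indices;
    alpha, gamma: (C,) per-class factors, looked up by each sample's true class c.
    """
    C = logits.shape[1]
    p = logits.softmax(dim=1)                        # p(i, j)
    q = F.one_hot(target, num_classes=C).float()     # q(i, j)
    a = alpha[target].unsqueeze(1)                   # alpha(c), shape (N, 1)
    g = gamma[target].unsqueeze(1)                   # gamma(c), shape (N, 1)
    per_sample = (-a * (1 - q * p) ** g * q * torch.log(p)).sum(dim=1)
    return per_sample.mean()                         # 1/N * sum over i

torch.manual_seed(0)
logits = torch.randn(6, 4)
target = torch.tensor([0, 1, 2, 3, 1, 0])

ones, zeros = torch.ones(4), torch.zeros(4)
loss_as_ce = generalized_focal_loss(logits, target, alpha=ones, gamma=zeros)
print(loss_as_ce.item(), F.cross_entropy(logits, target).item())
```

Setting $\gamma(c) = 2$ for every class instead gives a strictly smaller loss, because every $(1-p(i,c))^2 < 1$.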
2. Implementation details for binary classification
CrossEntropyLoss
With hard labels, the entries of $q$ are only 0 or 1. For binary classification ($C=2$), the CrossEntropyLoss formula simplifies to:
$$L_{CE} = \frac{1}{N} \sum_{i}^{N} -[y(i) \log x(i)+(1-y(i)) \log (1-x(i))]$$
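where
- $y$ is the label tensor with shape $(N,)$. $y(i)$ is the label of sample $i$, and in binary classification $y(i)=0$ or $y(i)=1$.
- $x$ is the predicted probability that $y(i)=1$, with shape $(N,)$. $x(i)$ is the predicted probability for sample $i$. It is usually produced by a sigmoid, so $x(i) \in [0,1]$.

In the notation of the general formula, $p(i,1)=x(i)$, $p(i,0)=1-x(i)$, and $q(i,\cdot)$ is the one-hot encoding of $y(i)$.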
Next, we verify numerically that the simplified formula, the original formula, and the official PyTorch APIs agree:
```python
import torch
from torch import nn
import torch.nn.functional as F

# Simulated binary-classification logits and labels (pred is random; the outputs below are from one run)
pred = torch.randn(3)
'''
tensor([ 0.9764, -0.3312, -0.8662])
'''
target = torch.tensor([1, 0, 1])
'''
tensor([1, 0, 1])
'''

# Simplified formula
x = pred.sigmoid()
'''
tensor([0.7264, 0.4180, 0.2960])
'''
y = target.float()
'''
tensor([1., 0., 1.])
'''
bce1_none = -(y*torch.log(x)+(1-y)*torch.log(1-x))
'''
tensor([0.3197, 0.5412, 1.2173])
'''
bce1 = bce1_none.mean()
'''
tensor(0.6927)
'''

# Original formula
p = torch.stack([1-x, x], dim=1)
'''
tensor([[0.2736, 0.7264],
        [0.5820, 0.4180],
        [0.7040, 0.2960]])
'''
q = F.one_hot(target, num_classes=2)  # one_hot needs an integer tensor, so pass target rather than the float y
'''
tensor([[0, 1],
        [1, 0],
        [0, 1]])
'''
bce2_none = (-q*torch.log(p)).sum(dim=1)
'''
tensor([0.3197, 0.5412, 1.2173])
'''
bce2 = bce2_none.mean()
'''
tensor(0.6927)
'''

# PyTorch API: F.binary_cross_entropy (expects probabilities)
bce3_none = F.binary_cross_entropy(pred.sigmoid(), target.float(), reduction='none')
'''
tensor([0.3197, 0.5412, 1.2173])
'''
bce3 = bce3_none.mean()
'''
tensor(0.6927)
'''

# PyTorch API: F.binary_cross_entropy_with_logits (expects raw logits)
bce4_none = F.binary_cross_entropy_with_logits(pred, target.float(), reduction='none')
'''
tensor([0.3197, 0.5412, 1.2173])
'''
bce4 = bce4_none.mean()
'''
tensor(0.6927)
'''
```
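All four variants agree for moderate logits, but they behave differently numerically. The following sketch uses an extreme logit whose value is chosen only to trigger the problem. In float32, `sigmoid(200.)` rounds to exactly 1.0, so the hand-written formula produces `inf`. `F.binary_cross_entropy` clamps its log outputs at -100, as the PyTorch `BCELoss` documentation describes, and therefore returns 100. Only `F.binary_cross_entropy_with_logits` returns the exact value of 200.

```python
import torch
import torch.nn.functional as F

pred = torch.tensor([200.0])   # a very confident logit for class 1
y = torch.tensor([0.0])        # ...but the label is 0
x = pred.sigmoid()             # rounds to exactly 1.0 in float32

manual = -(y * torch.log(x) + (1 - y) * torch.log(1 - x))               # log(0) = -inf, so the loss is inf
clamped = F.binary_cross_entropy(x, y, reduction='none')                # log outputs clamped to >= -100
stable = F.binary_cross_entropy_with_logits(pred, y, reduction='none')  # works on the logit directly

print(manual, clamped, stable)
```

For this reason, the `_with_logits` variant is usually preferred in training code, and the fvcore implementation above uses it.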
FocalLoss
With hard labels, the entries of $q$ are only 0 or 1. For binary classification, the FocalLoss formula simplifies to:
$$L_{Focal} = \frac{1}{N} \sum_{i}^{N} -\alpha(i)\,[1-p_t(i)]^{\gamma}\,[y(i) \log x(i)+(1-y(i)) \log (1-x(i))]$$
$$\alpha(i)=y(i)\,\alpha+(1-y(i))(1-\alpha)$$
$$p_t(i)=y(i)\,x(i)+[1-y(i)][1-x(i)]$$
where
- $\alpha$ and $\gamma$ are constants. $\alpha$ is the positive/negative weighting factor for $y(i)=1$ (samples with $y(i)=0$ get $1-\alpha$), and $\gamma$ is the easy/hard weighting factor.
- $y$ is the label tensor with shape $(N,)$. $y(i)$ is the label of sample $i$, and in binary classification $y(i)=0$ or $y(i)=1$.
- $x$ is the predicted probability that $y(i)=1$, with shape $(N,)$. $x(i)$ is the predicted probability for sample $i$. It is usually produced by a sigmoid, so $x(i) \in [0,1]$.
Again, we verify numerically that the simplified formula, the original formula, and the official APIs agree:
```python
import torch
from torch import nn
import torch.nn.functional as F

# Simulated binary-classification logits and labels (pred is random; the outputs below are from one run)
pred = torch.randn(3)
'''
tensor([ 0.9764, -0.3312, -0.8662])
'''
target = torch.tensor([1, 0, 1])
'''
tensor([1, 0, 1])
'''

# Hyperparameters: alpha=0.25, gamma=2.0
alpha = 0.25
gamma = 2.0

# Simplified formula
x = pred.sigmoid()
'''
tensor([0.7264, 0.4180, 0.2960])
'''
y = target.float()
'''
tensor([1., 0., 1.])
'''
alpha_t = y * alpha + (1 - y) * (1 - alpha)
'''
tensor([0.2500, 0.7500, 0.2500])
'''
pt = y * x + (1 - y) * (1 - x)
'''
tensor([0.7264, 0.5820, 0.2960])
'''
bce = -(y*torch.log(x)+(1-y)*torch.log(1-x))
# Equivalent alternatives:
# bce = F.binary_cross_entropy(x, y, reduction='none')
# bce = F.binary_cross_entropy_with_logits(pred, y, reduction='none')
# bce = -torch.log(pt)
'''
tensor([0.3197, 0.5412, 1.2173])
'''
focal_loss1_none = alpha_t * ((1 - pt) ** gamma) * bce
'''
tensor([0.0060, 0.0709, 0.1508])
'''
focal_loss1 = focal_loss1_none.mean()
'''
tensor(0.0759)
'''

# Original (generalized) formula
alphas = y * alpha + (1 - y) * (1 - alpha)  # alpha(c) for each sample's true class
'''
tensor([0.2500, 0.7500, 0.2500])
'''
gammas = y * gamma + (1 - y) * gamma        # gamma(c); the same for both classes here
'''
tensor([2., 2., 2.])
'''
p = torch.stack([1-x, x], dim=1)
'''
tensor([[0.2736, 0.7264],
        [0.5820, 0.4180],
        [0.7040, 0.2960]])
'''
q = F.one_hot(target, num_classes=2)  # one_hot needs an integer tensor, so pass target rather than the float y
'''
tensor([[0, 1],
        [1, 0],
        [0, 1]])
'''
focal_loss2_none = (-alphas.reshape(-1, 1) * ((1 - q * p) ** gammas.reshape(-1, 1)) * q * torch.log(p)).sum(dim=1)
'''
tensor([0.0060, 0.0709, 0.1508])
'''
focal_loss2 = focal_loss2_none.mean()
'''
tensor(0.0759)
'''

# torchvision API: sigmoid_focal_loss (takes raw logits)
from torchvision.ops import sigmoid_focal_loss
focal_loss3_none = sigmoid_focal_loss(pred, target.float(), alpha, gamma, reduction='none')
'''
tensor([0.0060, 0.0709, 0.1508])
'''
focal_loss3 = focal_loss3_none.mean()
'''
tensor(0.0759)
'''

# fvcore API: sigmoid_focal_loss (requires fvcore; shadows the torchvision import above)
from fvcore.nn import sigmoid_focal_loss
focal_loss4_none = sigmoid_focal_loss(pred, target.float(), alpha, gamma, reduction='none')
'''
tensor([0.0060, 0.0709, 0.1508])
'''
focal_loss4 = focal_loss4_none.mean()
'''
tensor(0.0759)
'''
```
3. Implementation details for multi-class classification
In multi-class classification, the prediction `pred` has shape $(N, C)$. With hard labels, the ground-truth `target` has shape $(N,)$.
CrossEntropyLoss
```python
import torch
from torch import nn
import torch.nn.functional as F

pred = torch.randn(3, 5)  # random; the outputs below are from one run
'''
tensor([[ 1.0865, -0.6392,  0.0881, -0.5137,  1.4306],
        [ 0.0467, -0.4962, -1.5786,  2.0470,  0.8474],
        [-0.6024, -0.4759,  1.3952, -0.1280,  0.6135]])
'''
target = torch.tensor([1, 0, 4])
'''
tensor([1, 0, 4])
'''

# Original formula
p = pred.softmax(dim=1)
'''
tensor([[0.3166, 0.0564, 0.1166, 0.0639, 0.4466],
        [0.0877, 0.0510, 0.0173, 0.6486, 0.1954],
        [0.0690, 0.0783, 0.5088, 0.1109, 0.2329]])
'''
q = F.one_hot(target, num_classes=5)
'''
tensor([[0, 1, 0, 0, 0],
        [1, 0, 0, 0, 0],
        [0, 0, 0, 0, 1]])
'''
ce_loss1_none = (-q * torch.log(p)).sum(dim=1)
'''
tensor([2.8760, 2.4333, 1.4573])
'''

# PyTorch API: F.cross_entropy (takes raw logits)
ce_loss2_none = F.cross_entropy(pred, target, reduction='none')
'''
tensor([2.8760, 2.4333, 1.4573])
'''

# PyTorch API: F.nll_loss (takes log-probabilities)
ce_loss3_none = F.nll_loss(torch.log(p), target, reduction='none')
'''
tensor([2.8760, 2.4333, 1.4573])
'''
```
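The hand-written `(-q * torch.log(p)).sum(dim=1)` form is convenient for checking the formula, but it is fragile. The following sketch uses an extreme logit with illustrative values. When a softmax probability underflows to 0, `torch.log` returns `-inf`, and `0 * -inf` yields `nan`, even for a correctly classified sample. `F.cross_entropy` works on log-softmax values internally and stays finite. The same reasoning explains why the FocalLoss code that follows calls `F.log_softmax` first.

```python
import torch
import torch.nn.functional as F

pred = torch.tensor([[0.0, 200.0]])   # class 1 is overwhelmingly likely
target = torch.tensor([1])            # ...and it is the correct class

p = pred.softmax(dim=1)                        # tensor([[0., 1.]]): exp(-200) underflows in float32
q = F.one_hot(target, num_classes=2)
manual = (-q * torch.log(p)).sum(dim=1)        # the zero entry of q times log(0) = -inf gives nan

api = F.cross_entropy(pred, target, reduction='none')
via_log_softmax = F.nll_loss(F.log_softmax(pred, dim=1), target, reduction='none')

print(manual, api, via_log_softmax)
```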
FocalLoss
```python
import torch
from torch import nn
import torch.nn.functional as F

pred = torch.randn(3, 5)  # random; the outputs below are from one run
'''
tensor([[ 1.0865, -0.6392,  0.0881, -0.5137,  1.4306],
        [ 0.0467, -0.4962, -1.5786,  2.0470,  0.8474],
        [-0.6024, -0.4759,  1.3952, -0.1280,  0.6135]])
'''
target = torch.tensor([1, 0, 4])
'''
tensor([1, 0, 4])
'''

# Hyperparameters: alpha = 0.25 for class 0 (background) and 1 - 0.25 = 0.75 for the others; gamma = 2.0 for every class
alpha = torch.tensor([0.25, 0.75, 0.75, 0.75, 0.75])
gamma = torch.tensor([2., 2., 2., 2., 2.])

# Original (generalized) formula
p = pred.softmax(dim=1)
'''
tensor([[0.3166, 0.0564, 0.1166, 0.0639, 0.4466],
        [0.0877, 0.0510, 0.0173, 0.6486, 0.1954],
        [0.0690, 0.0783, 0.5088, 0.1109, 0.2329]])
'''
q = F.one_hot(target, num_classes=5)
'''
tensor([[0, 1, 0, 0, 0],
        [1, 0, 0, 0, 0],
        [0, 0, 0, 0, 1]])
'''
alpha_gathered = torch.gather(alpha, 0, target)  # alpha(c) for each sample's true class c
'''
tensor([0.7500, 0.2500, 0.7500])
'''
gamma_gathered = torch.gather(gamma, 0, target)  # gamma(c) for each sample's true class c
'''
tensor([2., 2., 2.])
'''
focal_loss1_none = (-alpha_gathered.reshape(-1, 1) * ((1 - (q * p)) ** gamma_gathered.reshape(-1, 1)) * (q * torch.log(p))).sum(dim=1)
'''
tensor([1.9208, 0.5063, 0.6432])
'''

# No official PyTorch API at the time of writing; adapted from https://github.com/yatengLG/Focal-Loss-Pytorch/blob/master/Focal_Loss.py
class focal_loss(nn.Module):
    def __init__(self, alpha=None, gamma=2, num_classes=3, size_average=True):
        """
        focal_loss: -α(1-yi)**γ * ce_loss(xi, yi)
        A step-by-step implementation of the focal loss.
        :param alpha: α, class weights. If α is a list, it gives the weight of each class. If α is a constant,
                      the class weights are [α, 1-α, 1-α, ...], commonly used in object detection to suppress
                      the background class (0.25 in RetinaNet).
        :param gamma: γ, easy/hard sample modulation parameter (2 in RetinaNet).
        :param num_classes: number of classes
        :param size_average: reduction mode; averages the loss by default
        """
        super(focal_loss, self).__init__()
        self.size_average = size_average
        if alpha is None:
            self.alpha = torch.ones(num_classes)
        elif isinstance(alpha, list):
            assert len(alpha) == num_classes  # α as a list of size [num_classes] weights each class individually
            self.alpha = torch.Tensor(alpha)
        else:
            assert alpha < 1  # a constant α down-weights the first class, which is background in object detection
            self.alpha = torch.zeros(num_classes)
            self.alpha[0] += alpha
            self.alpha[1:] += (1 - alpha)  # α becomes [α, 1-α, 1-α, 1-α, ...], size [num_classes]
        self.gamma = gamma
        print('Focal Loss:')
        print('    Alpha = {}'.format(self.alpha))
        print('    Gamma = {}'.format(self.gamma))

    def forward(self, preds, labels):
        """
        Compute the focal loss.
        :param preds: predicted class scores, size [B,N,C] for detection or [B,C] for classification
                      (B: batch size, N: number of boxes, C: number of classes)
        :param labels: ground-truth classes, size [B,N] or [B]
        :return: loss
        """
        # assert preds.dim()==2 and labels.dim()==1
        preds = preds.view(-1, preds.size(-1))
        alpha = self.alpha.to(preds.device)
        preds_logsoft = F.log_softmax(preds, dim=1)  # log_softmax
        preds_softmax = torch.exp(preds_logsoft)  # softmax
        preds_softmax = preds_softmax.gather(1, labels.view(-1, 1))  # this implements nll_loss (cross entropy = log_softmax + nll)
        preds_logsoft = preds_logsoft.gather(1, labels.view(-1, 1))
        alpha = alpha.gather(0, labels.view(-1))  # gather from the device copy; gathering from self.alpha would stay on the CPU
        loss = -torch.mul(torch.pow((1 - preds_softmax), self.gamma), preds_logsoft)  # torch.pow((1-preds_softmax), self.gamma) is (1-pt)**γ
        loss = torch.mul(alpha, loss.t())
        if self.size_average:
            loss = loss.mean()
        return loss

focal_loss2_none = focal_loss(alpha=0.25, gamma=2, num_classes=5, size_average=False)(pred, target)
'''
tensor([[1.9208, 0.5063, 0.6432]])
'''
```
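A more compact way to write the same multi-class (softmax) focal loss is sketched below. It is an illustration, not an official API. Because `F.cross_entropy(..., reduction='none')` returns $-\log p_t$ per sample, $p_t$ can be recovered as `exp(-ce)` without a second softmax. The function name `softmax_focal_loss` is hypothetical. `alpha` is per-class and indexed by the true class, as above, while `gamma` is a scalar here for simplicity.

```python
import torch
import torch.nn.functional as F

def softmax_focal_loss(logits, target, alpha, gamma=2.0, reduction='none'):
    """Hypothetical helper: multi-class focal loss -alpha(c) * (1 - p_t)**gamma * log(p_t).

    logits: (N, C) raw scores; target: (N,) int64 class indices;
    alpha: (C,) per-class weights indexed by the true class; gamma: scalar.
    """
    ce = F.cross_entropy(logits, target, reduction='none')   # -log(p_t), computed via log_softmax
    pt = torch.exp(-ce)                                     # p_t = predicted probability of the true class
    loss = alpha.to(logits.device)[target] * (1 - pt) ** gamma * ce
    if reduction == 'mean':
        return loss.mean()
    if reduction == 'sum':
        return loss.sum()
    return loss

torch.manual_seed(0)
pred = torch.randn(3, 5)
target = torch.tensor([1, 0, 4])
alpha = torch.tensor([0.25, 0.75, 0.75, 0.75, 0.75])

# Reference: the explicit one-hot formula used earlier in this section
p = pred.softmax(dim=1)
q = F.one_hot(target, num_classes=5)
reference = (-alpha[target].reshape(-1, 1) * (1 - q * p) ** 2.0 * q * torch.log(p)).sum(dim=1)

print(softmax_focal_loss(pred, target, alpha), reference)
```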
4. Multi-class foreground/background in SOTA networks
VirConv
VirConv is a camera + LiDAR 3D object detection network built on the idea of pseudo points.
Source code
```python
import torch
from torch import nn


class SigmoidFocalClassificationLoss(nn.Module):
    """
    Sigmoid focal cross entropy loss.
    """

    def __init__(self, gamma: float = 2.0, alpha: float = 0.25):
        """
        Args:
            gamma: Weighting parameter to balance loss for hard and easy examples.
            alpha: Weighting parameter to balance loss for positive and negative examples.
        """
        super(SigmoidFocalClassificationLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma

    @staticmethod
    def sigmoid_cross_entropy_with_logits(input: torch.Tensor, target: torch.Tensor):
        """ PyTorch Implementation for tf.nn.sigmoid_cross_entropy_with_logits:
            max(x, 0) - x * z + log(1 + exp(-abs(x))) in
            https://www.tensorflow.org/api_docs/python/tf/nn/sigmoid_cross_entropy_with_logits

        Args:
            input: (B, #anchors, #classes) float tensor.
                Predicted logits for each class
            target: (B, #anchors, #classes) float tensor.
                One-hot encoded classification targets

        Returns:
            loss: (B, #anchors, #classes) float tensor.
                Sigmoid cross entropy loss without reduction
        """
        loss = torch.clamp(input, min=0) - input * target + \
               torch.log1p(torch.exp(-torch.abs(input)))
        return loss

    def forward(self, input: torch.Tensor, target: torch.Tensor, weights: torch.Tensor):
        """
        Args:
            input: (B, #anchors, #classes) float tensor.
                Predicted logits for each class
            target: (B, #anchors, #classes) float tensor.
                One-hot encoded classification targets
            weights: (B, #anchors) float tensor.
                Anchor-wise weights.

        Returns:
            weighted_loss: (B, #anchors, #classes) float tensor after weighting.
        """
        pred_sigmoid = torch.sigmoid(input)
        alpha_weight = target * self.alpha + (1 - target) * (1 - self.alpha)
        pt = target * (1.0 - pred_sigmoid) + (1.0 - target) * pred_sigmoid
        focal_weight = alpha_weight * torch.pow(pt, self.gamma)

        bce_loss = self.sigmoid_cross_entropy_with_logits(input, target)

        loss = focal_weight * bce_loss

        if weights.shape.__len__() == 2 or \
                (weights.shape.__len__() == 1 and target.shape.__len__() == 2):
            weights = weights.unsqueeze(-1)

        assert weights.shape.__len__() == loss.shape.__len__()

        return loss * weights
```
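VirConv has to handle the background + multi-class case, which raises an obvious question. The background label is $0$ and the foreground labels are $1, \dots, C$, yet the last dimension of the classification head's output is only $C$, while the ground-truth labels contain $0$. How is FocalLoss computed in that case? Step by step:
- VirConv-V uses the VoxelRCNN framework. Its first stage takes a BEV feature map of shape $(B, 256, H, W)$ and, together with predefined anchors, proposes RoIs.
- The RoI classification head outputs a tensor of shape $(B, N, C)$, which is `input` in the code above.
- `target` in the code above is the one-hot encoding of the ground truth, generated by the code below. `cls_targets` holds ground-truth labels in $\{0, 1, \dots, C\}$, and the one-hot tensor built from it has shape $(B, N, C+1)$. Note that the first channel (index 0, background) of the last dimension is then discarded, so the shape becomes $(B, N, C)$ and a background anchor ends up as an all-zero row.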
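```python
# Excerpt from the anchor head's classification loss; box_cls_labels, cared,
# cls_preds, batch_size and self.num_class are defined in the surrounding code.
cls_targets = box_cls_labels * cared.type_as(box_cls_labels)  # labels in 0..C, 0 = background
cls_targets = cls_targets.unsqueeze(dim=-1)
cls_targets = cls_targets.squeeze(dim=-1)
one_hot_targets = torch.zeros(
    *list(cls_targets.shape), self.num_class + 1, dtype=cls_preds.dtype, device=cls_targets.device
)  # (B, N, C + 1)
one_hot_targets.scatter_(-1, cls_targets.unsqueeze(dim=-1).long(), 1.0)
cls_preds = cls_preds.view(batch_size, -1, self.num_class)  # (B, N, C)
one_hot_targets = one_hot_targets[..., 1:]  # drop the background column -> (B, N, C)
```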
- Inside the FocalLoss computation, `input` is first activated with a sigmoid.
- `alpha_weight` is computed exactly as in the original formula.
```python
alpha_weight = target * self.alpha + (1 - target) * (1 - self.alpha)
```
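- Compute `pt` and `focal_weight`. Note: this `pt` looks inconsistent with the original formula, but it is actually equivalent to $1-p_t$ in the paper. That is why the code uses `torch.pow(pt, self.gamma)` rather than `torch.pow(1-pt, self.gamma)`. See the official mmdetection code for the same convention.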
```python
pt = target * (1.0 - pred_sigmoid) + (1.0 - target) * pred_sigmoid
focal_weight = alpha_weight * torch.pow(pt, self.gamma)
```
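- Compute `bce_loss`, the per-class binary cross entropy. Note: the name `bce_loss` is not a typo. It denotes a binary cross entropy applied independently to every class of the multi-class output.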
```python
bce_loss = self.sigmoid_cross_entropy_with_logits(input, target)
```
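The `sigmoid_cross_entropy_with_logits` function deserves special attention. Unlike the ordinary multi-class cross entropy, the background + multi-class case requires changing how the cross entropy is computed: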
Conventional multi-class cross entropy:
$$\text{target} \cdot \big(-\log(\mathrm{softmax}(\text{input}))\big)$$
Modified per-class binary cross entropy:
$$\text{target} \cdot \big(-\log(\mathrm{sigmoid}(\text{input}))\big) + (1-\text{target}) \cdot \big(-\log(1-\mathrm{sigmoid}(\text{input}))\big)$$
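The key idea of this change is that each class becomes its own binary classification problem: the current class is the positive and all other classes are negatives. This explains two things: 1) the use of sigmoid as the activation; 2) the extra term, which is the contribution of the other classes as negatives. In particular, a background anchor, whose target row is all zeros, contributes only this second term for every class.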
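In addition, as the TensorFlow documentation linked in the docstring explains, the following formulation avoids overflow and improves numerical stability. The naive $\log(1+e^{-x})$ overflows for large negative $x$, whereas $\max(x,0) - xz + \log(1+e^{-|x|})$ only ever exponentiates non-positive numbers.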
```python
import torch


def sigmoid_cross_entropy_with_logits(input: torch.Tensor, target: torch.Tensor):
    """ PyTorch Implementation for tf.nn.sigmoid_cross_entropy_with_logits:
        max(x, 0) - x * z + log(1 + exp(-abs(x))) in
        https://www.tensorflow.org/api_docs/python/tf/nn/sigmoid_cross_entropy_with_logits

    Args:
        input: (B, #anchors, #classes) float tensor.
            Predicted logits for each class
        target: (B, #anchors, #classes) float tensor.
            One-hot encoded classification targets

    Returns:
        loss: (B, #anchors, #classes) float tensor.
            Sigmoid cross entropy loss without reduction
    """
    loss = torch.clamp(input, min=0) - input * target + \
           torch.log1p(torch.exp(-torch.abs(input)))
    return loss
```
- Compute the loss.
```python
loss = focal_weight * bce_loss
```
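The steps above can be reproduced end to end on toy data. The sketch below is a simplified re-implementation for illustration, not the VirConv/OpenPCDet code, and its labels and sizes are made up. Labels use 0 for background and 1..C for foreground classes. The one-hot tensor is built with C+1 columns and column 0 is dropped, so background anchors become all-zero rows. The sketch also checks two claims from the text. First, the stable `max(x,0) - x*z + log(1+exp(-|x|))` matches `F.binary_cross_entropy_with_logits`. Second, VirConv's `pt` (which is $1-p_t$ in the paper's notation) gives the same loss as the paper's form.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
num_class = 3                                   # foreground classes 1..3; 0 is background
B, N = 2, 4                                     # batch size, anchors per sample
cls_targets = torch.tensor([[0, 1, 3, 0],       # per-anchor labels, 0 = background
                            [2, 0, 0, 1]])
cls_preds = torch.randn(B, N, num_class)        # head output: C channels, no background channel

# One-hot with C + 1 columns, then drop column 0 (background)
one_hot = torch.zeros(B, N, num_class + 1)
one_hot.scatter_(-1, cls_targets.unsqueeze(-1), 1.0)
target = one_hot[..., 1:]                       # (B, N, C); background rows are all zeros

# Numerically stable sigmoid BCE, as in sigmoid_cross_entropy_with_logits
x, z = cls_preds, target
bce = torch.clamp(x, min=0) - x * z + torch.log1p(torch.exp(-torch.abs(x)))

# Focal weight as in VirConv's forward: its "pt" equals 1 - p_t of the paper
alpha, gamma = 0.25, 2.0
p = torch.sigmoid(x)
alpha_weight = z * alpha + (1 - z) * (1 - alpha)
pt_virconv = z * (1.0 - p) + (1.0 - z) * p
loss_virconv = alpha_weight * pt_virconv ** gamma * bce

# The same loss written in the paper's notation
p_t = p * z + (1 - p) * (1 - z)
loss_paper = alpha_weight * (1 - p_t) ** gamma * F.binary_cross_entropy_with_logits(x, z, reduction='none')

print(target[0])           # rows for anchors 0 and 3 are all zeros
print(loss_virconv.shape)  # (B, N, C): one binary loss per class per anchor
```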
CenterNet
CenterNet is a network used mainly for monocular 2D or monocular 3D detection.
Source code
```python
import torch
from torch import nn


class FocalLoss(nn.Module):
    '''nn.Module warpper for focal loss'''
    def __init__(self):
        super(FocalLoss, self).__init__()
        self.neg_loss = _neg_loss

    def forward(self, out, target):
        return self.neg_loss(out, target)


def _neg_loss(pred, gt):
    ''' Modified focal loss. Exactly the same as CornerNet.
        Runs faster and costs a little bit more memory
      Arguments:
        pred (batch x c x h x w)
        gt_regr (batch x c x h x w)
    '''
    pos_inds = gt.eq(1).float()          # positives: exact heatmap peaks (gt == 1)
    neg_inds = gt.lt(1).float()          # negatives: every other location

    neg_weights = torch.pow(1 - gt, 4)   # reduces the penalty near object centers

    loss = 0

    pos_loss = torch.log(pred) * torch.pow(1 - pred, 2) * pos_inds
    neg_loss = torch.log(1 - pred) * torch.pow(pred, 2) * neg_weights * neg_inds

    num_pos = pos_inds.float().sum()
    pos_loss = pos_loss.sum()
    neg_loss = neg_loss.sum()

    if num_pos == 0:
        loss = loss - neg_loss
    else:
        loss = loss - (pos_loss + neg_loss) / num_pos
    return loss
```
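CenterNet modifies the conventional FocalLoss. Let us break it down:
Note that `pred` has already been activated with a sigmoid rather than a softmax before it reaches FocalLoss. This is the usual approach for the multi-class foreground/background case: each class channel is treated as an independent binary problem, as in VirConv.
- The function returns `loss`, so we start the breakdown there. `num_pos` plays the role of $N$ in the original formula and normalizes the sum. Note, however, that it counts only positive locations (object centers), not all locations. RetinaNet uses a similar normalization: it divides by the number of anchors assigned to a ground-truth box.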
```python
if num_pos == 0:
    loss = loss - neg_loss
else:
    loss = loss - (pos_loss + neg_loss) / num_pos
return loss
```
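- Expanding `loss` shows clearly that $\gamma=2$ is used, while $\alpha$ is modified: $\alpha=1$ for positives and $\alpha=(1-gt)^4$ for negatives. The modification exists because the ground truth is not strictly 0/1. CenterNet splats each object center onto the heatmap with a Gaussian kernel, so locations near a center have $0 < gt < 1$. Only the exact peak ($gt=1$) counts as a positive, and $(1-gt)^4$ reduces the penalty on negatives close to a center, an approach similar in spirit to label smoothing.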
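```python
# Expanded form (assumes num_pos > 0)
num_pos = gt.eq(1).float().sum()
pos_loss = 1 * torch.pow(1 - pred, 2) * gt.eq(1).float() * torch.log(pred)                                # alpha = 1, gamma = 2
neg_loss = torch.pow(1 - gt, 4) * torch.pow(1 - (1 - pred), 2) * gt.lt(1).float() * torch.log(1 - pred)  # alpha = (1 - gt)^4, gamma = 2
loss = - 1 / num_pos * (pos_loss.sum() + neg_loss.sum())
```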