Study Notes on CrossEntropyLoss and FocalLoss


Preface

Study notes on CrossEntropyLoss and FocalLoss.
Most existing tutorials take binary classification as the example and neglect the multi-class case; this article aims to fill that gap.
Moreover, in object detection the multi-class case is further generalized to a multi-class foreground/background setting. This special case is rarely discussed and has confused many people.
This article briefly reviews the formula derivation and the binary case, focuses the discussion on the multi-class foreground/background case, and supplements it with implementation details from several SOTA networks for that case, together with personal understanding and discussion.


1. Background Review

CrossEntropyLoss

The CrossEntropyLoss formula:
$$L_{CE} = \frac{1}{N} \sum_{i}^{N} \sum_{j}^{C} -q(i,j) \log p(i,j)$$
where $N$ is the number of samples and $C$ is the number of classes. $q$ has shape $(N, C)$, and $q(i,j)$ is the ground-truth probability that sample $i$ belongs to class $j$. $p$ has shape $(N, C)$, and $p(i,j)$ is the predicted probability that sample $i$ belongs to class $j$, usually produced by a softmax activation, which guarantees $p(i,j) \in [0,1]$ and $\sum_j^C p(i,j)=1$.
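For a quick worked example with $N=1$ and $C=3$: if $q=(0,1,0)$ and $p=(0.2,0.7,0.1)$, only the true-class term survives, giving $L_{CE}=-\log 0.7 \approx 0.357$.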

FocalLoss in the Original Paper

The notation in this subsection follows the original paper and differs from the notation used elsewhere in this article; please keep the two separate.
Following the original paper, Focal Loss for Dense Object Detection, Focal Loss is introduced using a binary classification task. The paper first defines CrossEntropy Loss in the following form (note that no mean over multiple samples is taken here):
$$CE(p,y)=\begin{cases}-\log(p) & \text{if } y=1 \\ -\log(1-p) & \text{otherwise}\end{cases}$$
where $y\in\{\pm1\}$ is the ground-truth class and $p\in[0,1]$ is the predicted probability of $y=1$.
Further, for notational convenience, define $p_t$:
$$p_t=\begin{cases}p & \text{if } y=1 \\ 1-p & \text{otherwise}\end{cases}$$
CrossEntropy Loss then simplifies to:
$$CE(p,y)=CE(p_t)=-\log(p_t)$$
To balance the weights of different classes, a class weight is introduced: $\alpha \in [0, 1]$ is the weight for $y=1$ and $1-\alpha$ the weight for $y=-1$. The weighted CrossEntropy Loss becomes:
$$CE(p_t)=-\alpha_t\log(p_t)$$
Since $\alpha_t$ essentially only mitigates the imbalance between the contributions of positive and negative samples to the loss, and does not address the imbalance between easy and hard samples, Focal Loss is proposed:
$$FL(p_t) = -\alpha_t (1 - p_t)^{\gamma} \log(p_t)$$
where $\alpha_t$ is the positive/negative-sample weighting factor and $\gamma\geq0$ is the easy/hard-sample weighting factor.
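To build intuition for the modulating factor $(1-p_t)^\gamma$, here is a small numeric sketch (values chosen purely for illustration) comparing an easy and a hard example with $\gamma=2$:

import math

gamma = 2.0
for p_t in (0.9, 0.1):           # easy vs. hard example
    ce = -math.log(p_t)          # plain cross entropy
    weight = (1 - p_t) ** gamma  # focal modulating factor
    print(f"p_t={p_t}: CE={ce:.4f}, FL={weight * ce:.4f}, weight={weight:.2f}")
# p_t=0.9: CE=0.1054, FL=0.0011, weight=0.01
# p_t=0.1: CE=2.3026, FL=1.8651, weight=0.81

The easy example's loss is scaled down by 100x while the hard example keeps most of its loss, shifting the focus of training onto hard examples.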
Combining this with Facebook's open-source fvcore code:

def sigmoid_focal_loss(
    inputs: torch.Tensor,
    targets: torch.Tensor,
    alpha: float = -1,
    gamma: float = 2,
    reduction: str = "none",
) -> torch.Tensor:
    """
    Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.
    Args:
        inputs: A float tensor of arbitrary shape.
                The predictions for each example.
        targets: A float tensor with the same shape as inputs. Stores the binary
                 classification label for each element in inputs
                (0 for the negative class and 1 for the positive class).
        alpha: (optional) Weighting factor in range (0,1) to balance
                positive vs negative examples. Default = -1 (no weighting).
        gamma: Exponent of the modulating factor (1 - p_t) to
               balance easy vs hard examples.
        reduction: 'none' | 'mean' | 'sum'
                 'none': No reduction will be applied to the output.
                 'mean': The output will be averaged.
                 'sum': The output will be summed.
    Returns:
        Loss tensor with the reduction option applied.
    """
    inputs = inputs.float()
    targets = targets.float()
    p = torch.sigmoid(inputs)
    ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
    p_t = p * targets + (1 - p) * (1 - targets)
    loss = ce_loss * ((1 - p_t) ** gamma)

    if alpha >= 0:
        alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
        loss = alpha_t * loss

    if reduction == "mean":
        loss = loss.mean()
    elif reduction == "sum":
        loss = loss.sum()

    return loss

Generalizing FocalLoss

Building on the analysis of the paper and the code above, we adopt a unified notation and rewrite the Focal Loss formula in a generalized form:
$$L_{Focal} = \frac{1}{N} \sum_{i}^{N} \sum_{j}^{C} -\alpha(j)\,[1-q(i,j)\,p(i,j)]^{\gamma(j)}\, q(i,j) \log p(i,j)$$
where $N$ is the number of samples and $C$ is the number of classes. $\alpha$ has shape $(C,)$, and $\alpha(j)$ is the positive/negative-sample weighting factor of class $j$. $\gamma$ has shape $(C,)$, and $\gamma(j)$ is the easy/hard-sample weighting factor of class $j$. $q$ has shape $(N, C)$, and $q(i,j)$ is the ground-truth probability that sample $i$ belongs to class $j$. $p$ has shape $(N, C)$, and $p(i,j)$ is the predicted probability that sample $i$ belongs to class $j$, usually produced by a softmax activation, which guarantees $p(i,j) \in [0,1]$ and $\sum_j^C p(i,j)=1$.
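For concreteness, here is a minimal vectorized sketch of this generalized form (our own illustration under hard labels with softmax activation; generalized_focal_loss is not a library API). Section 3 below verifies the same computation against official APIs:

import torch
import torch.nn.functional as F

def generalized_focal_loss(logits, target, alpha, gamma):
    """logits: (N, C) raw scores; target: (N,) hard class indices;
    alpha, gamma: (C,) per-class weighting factors."""
    p = logits.softmax(dim=1)                                   # p(i, j)
    q = F.one_hot(target, num_classes=logits.size(1)).float()  # q(i, j)
    a = alpha.unsqueeze(0)                                      # alpha(j), broadcast over samples
    g = gamma.unsqueeze(0)                                      # gamma(j)
    loss = -a * (1 - q * p) ** g * q * torch.log(p)             # (N, C) per-class terms
    return loss.sum(dim=1).mean()                               # sum over classes, average over N

Under hard labels only the ground-truth column of q is nonzero, so each sample effectively receives the alpha and gamma of its target class, which is what the gather-based code in section 3 exploits.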


2. Practical Details in Binary Classification

CrossEntropyLoss

Under hard labels, $q$ contains only 0s and 1s, and the CrossEntropyLoss formula simplifies to:
$$L_{CE} = \frac{1}{N} \sum_{i}^{N} -[y(i) \log x(i)+(1-y(i)) \log (1-x(i))]$$
where $y$ is the label vector of shape $(N,)$; $y(i)$ is the label of sample $i$, with $y(i)=0$ or $y(i)=1$ in binary classification. $x$ is the predicted probability of $y(i)=1$, with shape $(N,)$; $x(i)$ is the predicted probability for sample $i$, usually produced by a sigmoid activation so that $x(i) \in [0,1]$.
Next, we verify in code that the simplified formula, the original formula, and the official PyTorch APIs agree:

import torch
from torch import nn
import torch.nn.functional as F

# Simulate the output of a binary-classification network and the labels
pred = torch.randn(3)
'''
tensor([ 0.9764, -0.3312, -0.8662])
'''
target = torch.tensor([1, 0, 1])
'''
tensor([1, 0, 1])
'''

# Computation via the simplified formula
x = pred.sigmoid()
'''
 tensor([0.7264, 0.4180, 0.2960])
'''
y = target.float()
'''
tensor([1., 0., 1.])
'''
bce1_none = -(y*torch.log(x)+(1-y)*torch.log(1-x))
'''
tensor([0.3197, 0.5412, 1.2173])
'''
bce1 = bce1_none.mean()
'''
tensor(0.6927)
'''

# Computation via the original formula
p = torch.stack([1-x, x], dim=1)
'''
tensor([[0.2736, 0.7264],
        [0.5820, 0.4180],
        [0.7040, 0.2960]])
'''
q = F.one_hot(target, num_classes=2)  # note: one_hot requires an integer tensor, so use target rather than the float y
'''
tensor([[0, 1],
        [1, 0],
        [0, 1]])
'''
bce2_none = (-q*torch.log(p)).sum(dim=1)
'''
tensor([0.3197, 0.5412, 1.2173])
'''
bce2 = bce2_none.mean()
'''
tensor(0.6927)
'''

# pytorch API: F.binary_cross_entropy
bce3_none = F.binary_cross_entropy(pred.sigmoid(), target.float(), reduction='none')
'''
tensor([0.3197, 0.5412, 1.2173])
'''
bce3 = bce3_none.mean()
'''
tensor(0.6927)
'''

# pytorch API: F.binary_cross_entropy_with_logits
bce4_none = F.binary_cross_entropy_with_logits(pred, target.float(), reduction='none')
'''
tensor([0.3197, 0.5412, 1.2173])
'''
bce4 = bce4_none.mean()
'''
tensor(0.6927)
'''

FocalLoss

Under hard labels, $q$ contains only 0s and 1s, and the FocalLoss formula simplifies to:
$$L_{Focal} = \frac{1}{N} \sum_{i}^{N} -\alpha(i)\,[1-p_t(i)]^\gamma\, [y(i) \log x(i)+(1-y(i)) \log (1-x(i))]$$
$$\alpha(i)=y(i)\,\alpha+(1-y(i))(1-\alpha)$$
$$p_t(i)=y(i)\,x(i)+[1-y(i)][1-x(i)]$$
where $\alpha$ and $\gamma$ are constants, and $\alpha$ is the positive/negative-sample weighting factor for $y(i)=1$. $y$ is the label vector of shape $(N,)$; $y(i)$ is the label of sample $i$, with $y(i)=0$ or $y(i)=1$ in binary classification. $x$ is the predicted probability of $y(i)=1$, with shape $(N,)$; $x(i)$ is the predicted probability for sample $i$, usually produced by a sigmoid activation so that $x(i) \in [0,1]$.
Next, we verify in code that the simplified formula, the original formula, and the official APIs agree:

import torch
from torch import nn
import torch.nn.functional as F


# Simulate the output of a binary-classification network and the labels
pred = torch.randn(3)
'''
tensor([ 0.9764, -0.3312, -0.8662])
'''
target = torch.tensor([1, 0, 1])
'''
tensor([1, 0, 1])
'''

# Set the hyperparameters alpha=0.25, gamma=2.0
alpha = 0.25
gamma = 2.0

# Computation via the simplified formula
x = pred.sigmoid()
'''
 tensor([0.7264, 0.4180, 0.2960])
'''
y = target.float()
'''
tensor([1., 0., 1.])
'''
alpha_t = y * alpha + (1 - y) * (1 - alpha)
'''
tensor([0.2500, 0.7500, 0.2500])
'''
pt = y * x + (1 - y) * (1 - x)
'''
tensor([0.7264, 0.5820, 0.2960])
'''
bce = -(y*torch.log(x)+(1-y)*torch.log(1-x))
# bce = F.binary_cross_entropy(x, y, reduction='none')
# bce = F.binary_cross_entropy_with_logits(pred, y, reduction='none')
# bce = -torch.log(pt)
'''
tensor([0.3197, 0.5412, 1.2173])
'''
focal_loss1_none = alpha_t * ((1 - pt) ** gamma) * bce
'''
tensor([0.0060, 0.0709, 0.1508])
'''
focal_loss1 = focal_loss1_none.mean()
'''
tensor(0.0759)
'''

# Computation via the original formula
alphas = y * alpha + (1 - y) * (1 - alpha)
'''
tensor([0.2500, 0.7500, 0.2500])
'''
gammas = y * gamma + (1 - y) * gamma
'''
tensor([2., 2., 2.])
'''
p = torch.stack([1-x, x], dim=1)
'''
tensor([[0.2736, 0.7264],
        [0.5820, 0.4180],
        [0.7040, 0.2960]])
'''
q = F.one_hot(target, num_classes=2)  # note: one_hot requires an integer tensor, so use target rather than the float y
'''
tensor([[0, 1],
        [1, 0],
        [0, 1]])
'''
focal_loss2_none = (-alphas.reshape(-1, 1) * ((1 - q * p) ** gammas.reshape(-1, 1)) * q * torch.log(p)).sum(dim=1)
'''
tensor([0.0060, 0.0709, 0.1508])
'''
focal_loss2 = focal_loss2_none.mean()
'''
tensor(0.0759)
'''

# torchvision API: sigmoid_focal_loss
from torchvision.ops import sigmoid_focal_loss
focal_loss3_none = sigmoid_focal_loss(pred, target.float(), alpha, gamma, reduction='none')
'''
tensor([0.0060, 0.0709, 0.1508])
'''
focal_loss3 = focal_loss3_none.mean()
'''
tensor(0.0759)
'''

# fvcore API: sigmoid_focal_loss
from fvcore.nn import sigmoid_focal_loss 
focal_loss4_none = sigmoid_focal_loss(pred, target.float(), alpha, gamma, reduction='none')
'''
tensor([0.0060, 0.0709, 0.1508])
'''
focal_loss4 = focal_loss4_none.mean()
'''
tensor(0.0759)
'''

3. Practical Details in Multi-class Classification

In multi-class tasks, the prediction $pred$ has shape $(N,C)$; under hard labels, the ground truth $target$ has shape $(N,)$.

CrossEntropyLoss

import torch
from torch import nn
import torch.nn.functional as F

pred = torch.randn(3, 5)
'''
tensor([[ 1.0865, -0.6392,  0.0881, -0.5137,  1.4306],
        [ 0.0467, -0.4962, -1.5786,  2.0470,  0.8474],
        [-0.6024, -0.4759,  1.3952, -0.1280,  0.6135]])
'''
target = torch.tensor([1, 0, 4])
'''
tensor([1, 0, 4])
'''

# Computation via the original formula
p = pred.softmax(dim=1)
'''
tensor([[0.3166, 0.0564, 0.1166, 0.0639, 0.4466],
        [0.0877, 0.0510, 0.0173, 0.6486, 0.1954],
        [0.0690, 0.0783, 0.5088, 0.1109, 0.2329]])
'''
q = F.one_hot(target, num_classes=5)
'''
tensor([[0, 1, 0, 0, 0],
        [1, 0, 0, 0, 0],
        [0, 0, 0, 0, 1]])
'''
ce_loss1_none = (-q * torch.log(p)).sum(dim=1)
'''
tensor([2.8760, 2.4333, 1.4573])
'''

# pytorch API: cross_entropy
ce_loss2_none = F.cross_entropy(pred, target, reduction='none')
'''
tensor([2.8760, 2.4333, 1.4573])
'''

# pytorch API: nll_loss
ce_loss3_none = F.nll_loss(torch.log(p), target, reduction='none')
'''
tensor([2.8760, 2.4333, 1.4573])
'''
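
# A numerical side note (our addition, continuing the example above): computing
# torch.log on a separately computed softmax can underflow for strongly negative
# logits; the fused F.log_softmax is the stable route, and feeding it to nll_loss
# reproduces cross_entropy exactly.
log_p = F.log_softmax(pred, dim=1)
ce_loss4_none = F.nll_loss(log_p, target, reduction='none')
'''
tensor([2.8760, 2.4333, 1.4573])
'''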

FocalLoss

import torch
from torch import nn
import torch.nn.functional as F

pred = torch.randn(3, 5)
'''
tensor([[ 1.0865, -0.6392,  0.0881, -0.5137,  1.4306],
        [ 0.0467, -0.4962, -1.5786,  2.0470,  0.8474],
        [-0.6024, -0.4759,  1.3952, -0.1280,  0.6135]])
'''
target = torch.tensor([1, 0, 4])
'''
tensor([1, 0, 4])
'''

# Set the hyperparameters: per-class alpha (0.25 for class 0, 0.75 for the rest) and gamma=2.0
alpha = torch.tensor([0.25, 0.75, 0.75, 0.75, 0.75])
gamma = torch.tensor([2., 2., 2., 2., 2.])

# Computation via the original formula
p = pred.softmax(dim=1)
'''
tensor([[0.3166, 0.0564, 0.1166, 0.0639, 0.4466],
        [0.0877, 0.0510, 0.0173, 0.6486, 0.1954],
        [0.0690, 0.0783, 0.5088, 0.1109, 0.2329]])
'''
q = F.one_hot(target, num_classes=5)
'''
tensor([[0, 1, 0, 0, 0],
        [1, 0, 0, 0, 0],
        [0, 0, 0, 0, 1]])
'''
alpha_gathered = torch.gather(alpha, 0, target)
'''
tensor([0.7500, 0.2500, 0.7500])
'''
gamma_gathered = torch.gather(gamma, 0, target)
'''
tensor([2., 2., 2.])
'''
focal_loss1_none = (-alpha_gathered.reshape(-1, 1) * ((1 - (q * p)) ** gamma_gathered.reshape(-1, 1)) * (q * torch.log(p))).sum(dim=1)
'''
tensor([1.9208, 0.5063, 0.6432])
'''

# No official API exists at the moment; see https://github.com/yatengLG/Focal-Loss-Pytorch/blob/master/Focal_Loss.py
class focal_loss(nn.Module):
    def __init__(self, alpha=None, gamma=2, num_classes=3, size_average=True):
        """
        focal_loss: -α(1-yi)**γ * ce_loss(xi, yi)
        A step-by-step implementation of the focal loss.
        :param alpha: class weight α. When a list, it gives per-class weights; when a scalar,
                      the weights become [α, 1-α, 1-α, ...], commonly used in object detection
                      to suppress the background class. RetinaNet uses 0.25.
        :param gamma: easy/hard-sample modulation parameter γ. RetinaNet uses 2.
        :param num_classes: number of classes
        :param size_average: reduction mode; defaults to taking the mean
        """
        super(focal_loss, self).__init__()
        self.size_average = size_average
        if alpha is None:
            self.alpha = torch.ones(num_classes)
        elif isinstance(alpha, list):
            assert len(alpha) == num_classes  # a list alpha gives fine-grained per-class weights, size: [num_classes]
            self.alpha = torch.Tensor(alpha)
        else:
            assert alpha < 1  # a scalar alpha down-weights the first class (the background class in object detection)
            self.alpha = torch.zeros(num_classes)
            self.alpha[0] += alpha
            self.alpha[1:] += (1 - alpha)  # alpha ends up as [α, 1-α, 1-α, ...], size: [num_classes]

        self.gamma = gamma

        print('Focal Loss:')
        print('    Alpha = {}'.format(self.alpha))
        print('    Gamma = {}'.format(self.gamma))

    def forward(self, preds, labels):
        """
        Compute the focal loss.
        :param preds: predicted logits, size [B, N, C] or [B, C] (detection vs. classification);
                      B is the batch size, N the number of boxes, C the number of classes
        :param labels: ground-truth classes, size [B, N] or [B]
        :return:
        """
        # assert preds.dim()==2 and labels.dim()==1
        preds = preds.view(-1, preds.size(-1))
        alpha = self.alpha.to(preds.device)
        preds_logsoft = F.log_softmax(preds, dim=1)  # log_softmax
        preds_softmax = torch.exp(preds_logsoft)     # softmax

        preds_softmax = preds_softmax.gather(1, labels.view(-1, 1))  # this implements nll_loss (cross_entropy = log_softmax + nll_loss)
        preds_logsoft = preds_logsoft.gather(1, labels.view(-1, 1))
        alpha = alpha.gather(0, labels.view(-1))  # gather from the device-moved alpha, not self.alpha
        loss = -torch.mul(torch.pow((1 - preds_softmax), self.gamma), preds_logsoft)  # torch.pow((1 - preds_softmax), self.gamma) is the (1 - pt)**γ term of focal loss

        loss = torch.mul(alpha, loss.t())
        if self.size_average:
            loss = loss.mean()
        return loss

focal_loss2_none = focal_loss(alpha=0.25, gamma=2, num_classes=5, size_average=False)(pred, target)
'''
tensor([[1.9208, 0.5063, 0.6432]])
'''

4. Multi-class Foreground/Background in SOTA Networks

VirConv

VirConv is a camera-and-LiDAR 3D object detection network built on the virtual-point idea.
Source code:

class SigmoidFocalClassificationLoss(nn.Module):
    """
    Sigmoid focal cross entropy loss.
    """

    def __init__(self, gamma: float = 2.0, alpha: float = 0.25):
        """
        Args:
            gamma: Weighting parameter to balance loss for hard and easy examples.
            alpha: Weighting parameter to balance loss for positive and negative examples.
        """
        super(SigmoidFocalClassificationLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma

    @staticmethod
    def sigmoid_cross_entropy_with_logits(input: torch.Tensor, target: torch.Tensor):
        """ PyTorch Implementation for tf.nn.sigmoid_cross_entropy_with_logits:
            max(x, 0) - x * z + log(1 + exp(-abs(x))) in
            https://www.tensorflow.org/api_docs/python/tf/nn/sigmoid_cross_entropy_with_logits

        Args:
            input: (B, #anchors, #classes) float tensor.
                Predicted logits for each class
            target: (B, #anchors, #classes) float tensor.
                One-hot encoded classification targets

        Returns:
            loss: (B, #anchors, #classes) float tensor.
                Sigmoid cross entropy loss without reduction
        """
        loss = torch.clamp(input, min=0) - input * target + \
               torch.log1p(torch.exp(-torch.abs(input)))
        return loss

    def forward(self, input: torch.Tensor, target: torch.Tensor, weights: torch.Tensor):
        """
        Args:
            input: (B, #anchors, #classes) float tensor.
                Predicted logits for each class
            target: (B, #anchors, #classes) float tensor.
                One-hot encoded classification targets
            weights: (B, #anchors) float tensor.
                Anchor-wise weights.

        Returns:
            weighted_loss: (B, #anchors, #classes) float tensor after weighting.
        """
        pred_sigmoid = torch.sigmoid(input)
        alpha_weight = target * self.alpha + (1 - target) * (1 - self.alpha)
        pt = target * (1.0 - pred_sigmoid) + (1.0 - target) * pred_sigmoid
        focal_weight = alpha_weight * torch.pow(pt, self.gamma)

        bce_loss = self.sigmoid_cross_entropy_with_logits(input, target)

        loss = focal_weight * bce_loss

        if weights.shape.__len__() == 2 or \
                (weights.shape.__len__() == 1 and target.shape.__len__() == 2):
            weights = weights.unsqueeze(-1)

        assert weights.shape.__len__() == loss.shape.__len__()

        return loss * weights

In practice, VirConv faces a background-class-plus-multi-class setting. What may be confusing is that the background class has label $0$ and the foreground classes have labels $1, ..., C$, while the classification head outputs a tensor whose class dimension is $C$, yet the ground-truth labels contain $0$. How is FocalLoss computed in this case? A step-by-step analysis follows:

  1. VirConv-V uses the VoxelRCNN framework. The first stage takes a BEV feature map of shape $(B, 256, H, W)$ and, combined with the preset anchors, proposes RoIs.
  2. After the RoI classification head, we get a tensor of shape $(B, N, C)$, i.e., the input in the code above.
  3. The target in the code above is the one-hot ground-truth vector, generated as shown in the code below, where cls_targets holds the ground-truth labels in $0, 1, ..., C$. The labels are turned into one-hot vectors of shape $(B, N, C+1)$; crucially, dimension 0 of the one-hot vector is then dropped, yielding shape $(B, N, C)$ (a toy demonstration follows the snippet below).
cls_targets = box_cls_labels * cared.type_as(box_cls_labels)
cls_targets = cls_targets.unsqueeze(dim=-1)
cls_targets = cls_targets.squeeze(dim=-1)
one_hot_targets = torch.zeros(
    *list(cls_targets.shape), self.num_class + 1, dtype=cls_preds.dtype, device=cls_targets.device
)
one_hot_targets.scatter_(-1, cls_targets.unsqueeze(dim=-1).long(), 1.0)
cls_preds = cls_preds.view(batch_size, -1, self.num_class)
one_hot_targets = one_hot_targets[..., 1:]
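
As promised, a toy demonstration (hypothetical values, C=3) of why dropping column 0 works: a background anchor's target row becomes all zeros, so every class head treats it as a negative sample:

import torch

num_class = 3
cls_targets = torch.tensor([0, 2, 1])       # 0 = background, 1..C = foreground classes
one_hot_targets = torch.zeros(3, num_class + 1)
one_hot_targets.scatter_(-1, cls_targets.unsqueeze(dim=-1), 1.0)
'''
tensor([[1., 0., 0., 0.],
        [0., 0., 1., 0.],
        [0., 1., 0., 0.]])
'''
one_hot_targets = one_hot_targets[..., 1:]  # drop the background column
'''
tensor([[0., 0., 0.],
        [0., 1., 0.],
        [1., 0., 0.]])
'''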
  4. After entering the FocalLoss computation, input is first activated with the sigmoid function.
  5. The computation of alpha_weight is identical to the original formula:
alpha_weight = target * self.alpha + (1 - target) * (1 - self.alpha)
  6. Compute pt and focal_weight. Note: the pt here looks inconsistent with the original formula, but it is actually equivalent to 1-pt in the paper, which is why the next line uses torch.pow(pt, self.gamma) rather than torch.pow(1-pt, self.gamma); see the official mmdetection code for the same convention:
pt = target * (1.0 - pred_sigmoid) + (1.0 - target) * pred_sigmoid
focal_weight = alpha_weight * torch.pow(pt, self.gamma)
  7. Compute the multi-class binary cross entropy bce_loss. Note: the name bce_loss is not a mistake; it denotes binary cross entropy applied independently to each class in a multi-class setting:
bce_loss = self.sigmoid_cross_entropy_with_logits(input, target)

The sigmoid_cross_entropy_with_logits function deserves special attention. Unlike the usual multi-class cross entropy, the background-plus-multi-class setting requires modifying the conventional computation:
Conventional multi-class cross entropy:
$$target \cdot (-\log(\mathrm{softmax}(input)))$$
Modified per-class binary cross entropy:
$$target \cdot (-\log(\mathrm{sigmoid}(input))) + (1 - target) \cdot (-\log(1 - \mathrm{sigmoid}(input)))$$
The core of this change: each class is treated as an independent binary problem, with the current class as the positive and all other classes as negatives. This neatly explains (1) the use of the sigmoid activation and (2) the extra term, which is the contribution of the other classes as negatives. In particular, for a background anchor, whose target row is all zeros after dropping column 0, the loss reduces to $\sum_j -\log(1-\mathrm{sigmoid}(input_j))$, pushing every class score toward 0.

In addition, as the linked documentation explains in detail, the following formulation helps avoid overflow and improves numerical stability (a short derivation follows the snippet below):

def sigmoid_cross_entropy_with_logits(input: torch.Tensor, target: torch.Tensor):
        """ PyTorch Implementation for tf.nn.sigmoid_cross_entropy_with_logits:
            max(x, 0) - x * z + log(1 + exp(-abs(x))) in
            https://www.tensorflow.org/api_docs/python/tf/nn/sigmoid_cross_entropy_with_logits

        Args:
            input: (B, #anchors, #classes) float tensor.
                Predicted logits for each class
            target: (B, #anchors, #classes) float tensor.
                One-hot encoded classification targets

        Returns:
            loss: (B, #anchors, #classes) float tensor.
                Sigmoid cross entropy loss without reduction
        """
        loss = torch.clamp(input, min=0) - input * target + \
               torch.log1p(torch.exp(-torch.abs(input)))
        return loss
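
Why this form is numerically stable (a short derivation; $x$ is the logit and $z\in\{0,1\}$ the target): since $-\log \sigma(x) = \log(1+e^{-x})$ and $-\log(1-\sigma(x)) = x + \log(1+e^{-x})$, the per-element binary cross entropy becomes
$$-z\log\sigma(x)-(1-z)\log(1-\sigma(x)) = x - xz + \log(1+e^{-x}) = \max(x,0) - xz + \log(1+e^{-|x|})$$
The last form never exponentiates a positive number, so $e^{-|x|}$ cannot overflow, and log1p stays accurate for small arguments. A quick sanity check (our own test, not part of the source) against PyTorch's built-in stable implementation:

import torch
import torch.nn.functional as F

x = torch.randn(4, 3) * 50    # extreme logits that would overflow a naive exp
z = torch.randint(0, 2, (4, 3)).float()
stable = torch.clamp(x, min=0) - x * z + torch.log1p(torch.exp(-torch.abs(x)))
assert torch.allclose(stable, F.binary_cross_entropy_with_logits(x, z, reduction='none'))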
  8. Compute loss:
loss = focal_weight * bce_loss

CenterNet

CenterNet is a network mainly used for monocular 2D or monocular 3D detection.
Source code:

class FocalLoss(nn.Module):
  '''nn.Module wrapper for focal loss'''
  def __init__(self):
    super(FocalLoss, self).__init__()
    self.neg_loss = _neg_loss

  def forward(self, out, target):
    return self.neg_loss(out, target)

def _neg_loss(pred, gt):
  ''' Modified focal loss. Exactly the same as CornerNet.
      Runs faster and costs a little bit more memory
    Arguments:
      pred (batch x c x h x w)
      gt (batch x c x h x w)
  '''
  pos_inds = gt.eq(1).float()
  neg_inds = gt.lt(1).float()

  neg_weights = torch.pow(1 - gt, 4)

  loss = 0

  pos_loss = torch.log(pred) * torch.pow(1 - pred, 2) * pos_inds
  neg_loss = torch.log(1 - pred) * torch.pow(pred, 2) * neg_weights * neg_inds

  num_pos  = pos_inds.float().sum()
  pos_loss = pos_loss.sum()
  neg_loss = neg_loss.sum()

  if num_pos == 0:
    loss = loss - neg_loss
  else:
    loss = loss - (pos_loss + neg_loss) / num_pos
  return loss

CenterNet's implementation modifies the traditional FocalLoss; let's break it down step by step.
Note that pred has already been activated by a sigmoid (not softmax) before entering the FocalLoss computation; this is the standard approach for the multi-class foreground/background case.

  1. The function's return value is the loss variable, so we unpack it from there. num_pos plays the role of $N$ in the original formula, normalizing the result to a mean:

if num_pos == 0:
    loss = loss - neg_loss
else:
    loss = loss - (pos_loss + neg_loss) / num_pos
return loss
  1. loss展开,根据展开结果可以很清楚的得出:这里采用 γ = 2 \gamma=2 γ=2 α \alpha α则进行了特殊地修改,对于正样本 α = 1 \alpha=1 α=1,对于负样本 α = ( 1 − g t ) 4 \alpha=(1-gt)^4 α=(1gt)4,这是为了适配gt采用了与smooth label类似的做法。
loss = - 1 / num_pos * (pos_loss + neg_loss)
pos_loss = 1 * torch.pow(1 - pred, 2) * gt.eq(1).float() * torch.log(pred)
neg_loss = torch.pow(1 - gt, 4) * torch.pow(1 - (1 - pred), 2) * gt.lt(1).float() * torch.log(1 - pred)
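
To confirm the expansion, here is a quick toy check (hypothetical heatmap values, our own test) against _neg_loss above:

import torch

# A 1x1x3x3 ground-truth heatmap: one center (exactly 1.0) with a Gaussian-like
# halo, mimicking how CenterNet constructs its heatmaps.
gt = torch.tensor([[[[0.1, 0.5, 0.1],
                     [0.5, 1.0, 0.5],
                     [0.1, 0.5, 0.1]]]])
pred = torch.sigmoid(torch.randn(1, 1, 3, 3))  # predictions already lie in (0, 1)

pos_loss = (torch.log(pred) * (1 - pred) ** 2 * gt.eq(1).float()).sum()
neg_loss = (torch.log(1 - pred) * pred ** 2 * (1 - gt) ** 4 * gt.lt(1).float()).sum()
manual = -(pos_loss + neg_loss) / gt.eq(1).float().sum()

assert torch.allclose(manual, _neg_loss(pred, gt))  # matches the source implementation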