YOLOv9(4): DFL(Distribution Focal Loss)

1. 写在前面

        DFL最早在论文Generalized Focal Loss: Learning Qualified and Distributed Bounding Boxes for Dense Object Detection中被提及,在YOLOv8中就已经在使用。

DFL的作者认为,在很多时候,检测目标的边界并不是一个确切的值,而是一个分布。

2. DFL

        DFL,全程Distribution Focal Loss,很多同学一听到Focal Loss就立马想到分类,这没错,但DFL却是用在边框回归中。

        在大部分时候,我们理所当然的认为边框(L、T、R、B)是一个确切的值,从分布的角度来说,就是一个Dirac delta分布,即一个脉冲,有如下的表示。

\int_{-\infty }^{+\infty }\delta(x-y)dx=1

当我们给定标签的范围时,如y_{0}<=y<=y_{n},由上式可以继续推导出一个复原积分公式。

y\hat{}=\int_{-\infty }^{+\infty }P(x)xdx=\int_{y_{0}}^{y_{n}}P(x)xdx

OK,到这一步还是连续的,接下来我们来推出一个离散化的公式。

但是,在很多场景下,边缘本身并不是绝对的一个值,而更倾向于是一个分布,我们可以将这个广义的分布看作是P(y)。可用如下表示。

\sum_{i=0}^{n}P(y_{i})=1

为了复原估计值y\hat{},我们可以做如下表示。

y\hat{}=\sum_{i=0}^{n}P(y_{i})y_{i}

其中y_{i}是预测的边框取值。

怎么理解这里呢?具体这个怎么反应box框的长宽呢?

我们假设box的左边框距离中心距离5.6,我们使用一个长度为16的向量表达这个距离,只需要满

\sum_{i=0}^{15}P(y_{i})=1

\hat{y}=5.6=\sum_{i=0}^{15}(P(y_{i})y_{i})

特殊情况下,5.6取值在5和6之间,就是使得

P(5) + P(6) = 1,

P(5)*5 + P(6)*6 = 5.6

接下来我们定义DFL如下。

DFL(S_{i}, S_{i+1})=-((y_{i+1}-y)log(S_{i})+(y-y_{i})log(S_{i+1}))

3. 代码阅读

        在YOLOv9中,在计算预测box的Loss时,使用到了如上述定义的DFL Loss。

BboxLoss中的_df_loss

        如下代码所示,

    def _df_loss(self, pred_dist, target):

        target_left = target.to(torch.long)

        target_right = target_left + 1

        weight_left = target_right.to(torch.float) - target

        weight_right = 1 - weight_left

        loss_left = F.cross_entropy(pred_dist.view(-1, self.reg_max + 1), target_left.view(-1), reduction="none").view(target_left.shape) * weight_left

        loss_right = F.cross_entropy(pred_dist.view(-1, self.reg_max + 1), target_right.view(-1), reduction="none").view(target_left.shape) * weight_right

        return (loss_left + loss_right).mean(-1, keepdim=True)

        上述代码中,我们将时间要估计的y限制在target_left和target_right之间,-log(S_{i})

使用PyTorch的cross_entropy。

### Focal Loss 的分布情况及其在机器学习中的实现 Focal loss 是一种用于解决类别不平衡问题的损失函数,在目标检测和其他分类任务中表现出色。该损失函数通过调整难易样本权重的方式,使得模型更加关注难以分类的样本。 #### Focal Loss 数学表达式 Focal loss 可以表示为: \[ FL(p_t) = -(1-p_t)^{\gamma}\log(p_t) \] 其中 \(p_t\) 表示预测概率,对于正类而言即为 \(\hat{y}_i\);对于负类则是 \(1-\hat{y}_i\)。\(\gamma\) 控制着容易分错的样本所占的比例大小,当设置不同的 \(\gamma\) 值时可以改变 focal loss 对不同难度样本的关注程度[^1]。 #### 实现代码示例 下面是一个简单的 PyTorch 版本的 focal loss 实现方式: ```python import torch import torch.nn as nn import torch.nn.functional as F class FocalLoss(nn.Module): def __init__(self, alpha=1, gamma=2, reduction=&#39;mean&#39;): super(FocalLoss, self).__init__() self.alpha = alpha self.gamma = gamma self.reduction = reduction def forward(self, inputs, targets): BCE_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduce=False) pt = torch.exp(-BCE_loss) F_loss = self.alpha * (1-pt)**self.gamma * BCE_loss if self.reduction == &#39;sum&#39;: return torch.sum(F_loss) elif self.reduction == &#39;mean&#39;: return torch.mean(F_loss) else: return F_loss ``` 此段代码定义了一个 `FocalLoss` 类继承自 `nn.Module` 并实现了前向传播方法。它接受输入张量(inputs)、标签张量(targets),并返回计算后的焦点损失值。参数 `\alpha`, `\gamma` 和 `reduction` 分别用来控制加权系数、聚焦因子以及最终输出形式的选择。 #### 使用场景 通常情况下,focal loss 被应用于二元或多分类问题特别是那些具有严重类别不均衡的数据集上。比如图像识别领域内的物体检测任务,因为背景像素远多于前景对象,所以采用普通的交叉熵可能会导致训练过程中忽略掉稀疏的小目标。此时引入 focal loss 就能有效缓解这一现象,提高整体性能表现。
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值