还看不懂 DETR 的匈牙利损失函数?4个公式教你理解

看到 DETR 的损失函数的时候,你是否有下面的疑问:

  • 公式中的 σ ∈ S N \sigma \in \mathfrak{S}_N σSN 是什么意思?
  • 公式中的 y ^ σ ( i ) \hat{y}_{\sigma(i)} y^σ(i) 的下标 σ ( i ) \sigma(i) σ(i) 又有什么含义?
  • DETR 的损失函数计算的完整流程又是怎么样的?
  • 为什么计算 box 损失的时候为什么要加上 GIOU 损失

等等问题,都可以在下面的文章中得到解答。

概述

在 DETR 中,进行梯度更新可以分成 2 步:

  1. 使用匈牙利匹配算法,根据优化函数求解集合 y y y y ^ \hat{y} y^ 的最佳匹配:集合 y ^ \hat{y} y^ 的排列 σ ^ \hat{\sigma} σ^ σ ^ = arg min ⁡ σ ∈ S N ∑ i N L match ( y i , y ^ σ ( i ) ) \hat{\sigma}=\argmin_{\sigma \in \mathfrak{S}_N }\sum_i^{N}\mathcal{L}_{\text{match}}(y_i, \hat{y}_{\sigma(i)}) σ^=σSNargminiNLmatch(yi,y^σ(i)) L match ( y i , y ^ σ ( i ) ) = − 1 { c i ≠ ∅ } p ^ σ ( i ) ( c i ) + 1 { c i ≠ ∅ } L box ( b i , b ^ σ ( i ) ) \mathcal{L}_{\text{match}}(y_i, \hat{y}_{\sigma(i)}) = -\mathbb{1}_{\{c_i\ne \varnothing\}}\hat{p}_{\sigma(i)}(c_i) + \mathbb{1}_{\{c_i\ne \varnothing\}}\mathcal{L}_{\text{box}}(b_i, \hat{b}_{\sigma(i)} ) Lmatch(yi,y^σ(i))=1{ci=}p^σ(i)(ci)+1{ci=}Lbox(bi,b^σ(i))
  2. 根据集合 y ^ \hat{y} y^ 最佳排列 σ ^ \hat{\sigma} σ^ 带入损失函数中求解损失,并进行梯度更新。 L Hungarian ( y , y ^ ) = ∑ i = 1 N [ − log ⁡ p ^ σ ^ ( i ) ( c i ) + 1 { c i ≠ ∅ } L box ( b i , b ^ σ ^ ( i ) ) ] \mathcal{L}_{\text{Hungarian}}(y, \hat{y}) = \sum_{i=1}^N \left[ -\log\hat{p}_{\hat{\sigma}(i)}(c_i) + \mathbb{1}_{\{c_i \ne \varnothing\}}\mathcal{L}_{\text{box}}(b_i, \hat{b}_{\hat{\sigma}(i)})\right] LHungarian(y,y^)=i=1N[logp^σ^(i)(ci)+1{ci=}Lbox(bi,b^σ^(i))] L box ( b i , b ^ σ ^ ( i ) ) = λ giou L giou ( b i , b ^ σ ^ ( i ) ) + λ L1 ∣ ∣ b i − b ^ σ ^ ( i ) ∣ ∣ 1 \mathcal{L}_{\text{box}}(b_i,\hat{b}_{\hat{\sigma}(i)}) =\lambda_{\text{giou}}\mathcal{L}_{\text{giou}}(b_i,\hat{b}_{\hat{\sigma}(i)}) + \lambda_{\text{L1}}||b_i - \hat{b}_{\hat{\sigma}(i)}||_1 Lbox(bi,b^σ^(i))=λgiouLgiou(bi,b^σ^(i))+λL1∣∣bib^σ^(i)1

可以看出来,其实想要理解 DETR 的损失函数是怎么计算的,只要理解上面的 4 个公式就行了。

第一步:求最佳 σ ^ \hat{\sigma} σ^

σ ^ = arg min ⁡ σ ∈ S N ∑ i N L match ( y i , y ^ σ ( i ) ) \hat{\sigma}=\argmin_{\sigma \in \mathfrak{S}_N }\sum_i^{N}\mathcal{L}_{\text{match}}(y_i, \hat{y}_{\sigma(i)}) σ^=σSNargminiNLmatch(yi,y^σ(i))(还不算是损失函数,只是通过匈牙利匹配算法求解最优排列的一个优化目标函数)

  • y = { y i } i = 1 N y=\{y_i\}_{i=1}^N y={yi}i=1N:表示 N N N个 ground truth 的集合,其中 y i y_i yi是第 i i i 个 ground truth,当然实际中,集合 y y y 中的 ground truth 数量是远小于 N N N的,为了让 y y y y ^ \hat{y} y^ 两个集合大小一致,在集合 y y y 中会使用 ∅ \varnothing (no object)来对集合进行填充。
  • y ^ = { y ^ i } i = 1 N \hat{y}=\{\hat{y}_i\}_{i=1}^N y^={y^i}i=1N:表示 N N N个 预测的集合,其中 y ^ i \hat{y}_i y^i是第 i i i 个预测。
  • σ \sigma σ:是一种预测值 y ^ \hat{y} y^ 的排列方式,我们知道集合 y y y 与集合 y ^ \hat{y} y^ 要一一匹配,然后进行排列,我们把 y y y 的排列顺序固定,就只需要调整 y ^ \hat{y} y^ 的排列顺序就可以了,而 σ \sigma σ就是表示的集合 y ^ \hat{y} y^ 的某种排列方式, y ^ σ ( i ) \hat{y}_{\sigma(i)} y^σ(i) 也只是表示,在 σ \sigma σ这种排列中,第 i i i 个预测值。
  • S N \mathfrak{S}_N SN:是排列 σ \sigma σ 的集合,也是一种对称群? arg min ⁡ σ ∈ S N \argmin_{\sigma \in \mathfrak{S}_N } argminσSN表示在 S N \mathfrak{S}_N SN内存在一种集合 y ^ \hat{y} y^ 的排列 σ \sigma σ,可以使得匈牙利匹配的 cost 最低。
  • L match ( y i , y ^ σ ( i ) ) \mathcal{L}_{\text{match}}(y_i, \hat{y}_{\sigma(i)}) Lmatch(yi,y^σ(i)):是 pair-wise matching cost ,一般是使用匈牙利算法进行计算的。

L match ( y i , y ^ σ ( i ) ) = − 1 { c i ≠ ∅ } p ^ σ ( i ) ( c i ) + 1 { c i ≠ ∅ } L box ( b i , b ^ σ ( i ) ) \mathcal{L}_{\text{match}}(y_i, \hat{y}_{\sigma(i)}) = -\mathbb{1}_{\{c_i\ne \varnothing\}}\hat{p}_{\sigma(i)}(c_i) + \mathbb{1}_{\{c_i\ne \varnothing\}}\mathcal{L}_{\text{box}}(b_i, \hat{b}_{\sigma(i)} ) Lmatch(yi,y^σ(i))=1{ci=}p^σ(i)(ci)+1{ci=}Lbox(bi,b^σ(i))每个 ground truth y i y_i yi都是由两部分信息组成的,类别 + 位置,也可以写成 y i = ( c i , b i ) y_i=(c_i, b_i) yi=(ci,bi),其中:

  • c i c_i ci:表示类别信息(也有可能是空 ∅ \varnothing
  • b i b_i bi:表示位置信息,是一个归一化(值都小于 1 )的向量,有 4 个值,分别表示 box 中心点的坐标和宽高。

对于预测值 y ^ σ ( i ) \hat{y}_{\sigma(i)} y^σ(i),我们将类别和位置信息定义为 y ^ σ ( i ) = ( p ^ σ ( i ) ( c i ) , b ^ σ ( i ) ) \hat{y}_{\sigma(i)}=(\hat{p}_{\sigma(i)}(c_i), \hat{b}_{\sigma(i)}) y^σ(i)=(p^σ(i)(ci),b^σ(i))

  • p ^ σ ( i ) ( c i ) \hat{p}_{\sigma(i)}(c_i) p^σ(i)(ci):我们已经知道了 ground truch 的类别信息 c i c_i ci,这个概率值是通过模型的分类器计算得出的,反映了模型对于该预测值属于类别 c i c_i ci 的确信度。

第二步:求损失

L Hungarian ( y , y ^ ) = ∑ i = 1 N [ − log ⁡ p ^ σ ^ ( i ) ( c i ) + 1 { c i ≠ ∅ } L box ( b i , b ^ σ ^ ( i ) ) ] \mathcal{L}_{\text{Hungarian}}(y, \hat{y}) = \sum_{i=1}^N \left[ -\log\hat{p}_{\hat{\sigma}(i)}(c_i) + \mathbb{1}_{\{c_i \ne \varnothing\}}\mathcal{L}_{\text{box}}(b_i, \hat{b}_{\hat{\sigma}(i)})\right] LHungarian(y,y^)=i=1N[logp^σ^(i)(ci)+1{ci=}Lbox(bi,b^σ^(i))]论文这里还将 b ^ σ ^ ( i ) \hat{b}_{\hat{\sigma}(i)} b^σ^(i)打错成了 b ^ σ ^ ( i ) \hat{b}_{\hat{\sigma}}(i) b^σ^(i)

  • σ ^ \hat{\sigma} σ^是最优的排列,也就是使得整体 cost 最小的 y ^ \hat{y} y^ 排列。
  • log ⁡ \log log:这里存在一个问题,为什么上面的 − 1 { c i ≠ ∅ } p ^ σ ( i ) ( c i ) -\mathbb{1}_{\{c_i\ne \varnothing\}}\hat{p}_{\sigma(i)}(c_i) 1{ci=}p^σ(i)(ci) 在这里就变成了 − log ⁡ p ^ σ ^ ( i ) ( c i ) -\log\hat{p}_{\hat{\sigma}(i)}(c_i) logp^σ^(i)(ci)。一个 no object ∅ \varnothing y y y) 与预测值( y ^ \hat{y} y^)的 L match ( y i , y ^ σ ( i ) ) \mathcal{L}_{\text{match}}(y_i, \hat{y}_{\sigma(i)}) Lmatch(yi,y^σ(i))匹配代价实际上并不取决于预测值,因为 c i = ∅ c_i = \varnothing ci=的时候,因为指示函数的关系, L match ( y i , y ^ σ ( i ) ) = 0 \mathcal{L}_{\text{match}}(y_i, \hat{y}_{\sigma(i)}) = 0 Lmatch(yi,y^σ(i))=0,也就是一个常数。在计算匹配代价(cost)的类别代价(cost)的时候,我们使用概率而不是对数概率,因为实际效果更好。
  • 上面的公式是为了求解最优的排列 σ ^ \hat{\sigma} σ^,而这里根据最优的排列 σ ^ \hat{\sigma} σ^ 来求解损失。一般来说,为了解决类间不平衡问题,会在 c i = ∅ c_i = \varnothing ci= 对数概率项前乘以 1 / 10 1/10 1/10来降低权重。

L box ( b i , b ^ σ ^ ( i ) ) = λ giou L giou ( b i , b ^ σ ^ ( i ) ) + λ L1 ∣ ∣ b i − b ^ σ ^ ( i ) ∣ ∣ 1 \mathcal{L}_{\text{box}}(b_i,\hat{b}_{\hat{\sigma}(i)}) =\lambda_{\text{giou}}\mathcal{L}_{\text{giou}}(b_i,\hat{b}_{\hat{\sigma}(i)}) + \lambda_{\text{L1}}||b_i - \hat{b}_{\hat{\sigma}(i)}||_1 Lbox(bi,b^σ^(i))=λgiouLgiou(bi,b^σ^(i))+λL1∣∣bib^σ^(i)1

  • 直接使用 L1 损失:因为 L1 损失是计算绝对值,但是在目标检测中,大目标和小目标,即便是有着相同的相对误差(relative error),其绝对误差,也就是 L1 损失值都会有很大差异,对尺度的支持比较差。
  • 所以为了缓解 L1 损失的尺度不变性比较差的问题,我们引入了 GIOU(关于 GIOU 的部分,我后续有时间再进行补充吧)

后言

还有什么疑问或者问题可以在评论区评论,我会尽可能的解答并更新文章。

<think>好的,用户想知道如何在DETR模型中打印损失函数。首先,我需要回忆一下DETR模型的结构和训练流程。DETR(Detection Transformer)是基于Transformer的目标检测模型,它使用匈牙利算法来匹配预测和真实框,并计算损失。 用户的问题具体是如何在代码中实现打印损失函数。我需要考虑DETR损失函数通常包括分类损失、边界框回归损失以及可能的其他辅助损失。在训练过程中,这些损失会被计算并汇总。 接下来,我需要参考用户提供的引用内容。引用[2]提到了Faster R-CNN和DETR的比较,其中DETR的预测结果显示在终端。这可能涉及到在训练循环中添加打印语句,或者在模型输出时捕获损失值。 在PyTorch中,训练循环通常包括前向传播、计算损失、反向传播和优化步骤。因此,用户可能需要在计算损失后,使用print语句输出损失值。例如,在每个batch或每个epoch后打印损失。 另外,引用[3]提到OpenVINO的INT8量化,但这里可能不直接相关。不过,用户可能需要确保在模型推理或训练时正确获取损失值,这需要检查模型的前向传播是否返回损失,或者在自定义的训练循环中显式计算损失。 可能的步骤包括:1. 确定模型的前向传播是否返回损失;2. 在训练循环中获取损失值;3. 使用打印语句输出损失。同时,需要注意DETR的损失计算可能涉及多个部分,需要分别打印或汇总。 需要检查DETR的官方实现或相关代码库,看看损失是如何计算和返回的。例如,在DETR的代码中,可能存在一个包含各项损失的字典,或者直接返回总损失。用户需要根据具体实现来访问这些值。 总结来说,用户需要找到计算损失的位置,并在训练循环中插入打印语句。可能还需要了解DETR损失函数结构,以便正确提取各个损失部分的值。</think>在DETR模型中打印损失函数,需要结合模型训练流程,通过代码捕获并输出损失值。以下是具体实现方式的分步说明: 1. **定位损失计算模块** DETR损失函数包含分类损失$L_{cls}$和边界框损失$L_{box}$,采用匈牙利匹配算法对齐预测与真实标签。核心代码通常位于`matcher`模块和`criterion`模块中[^2]。 2. **在训练循环中获取损失** 在PyTorch训练循环中,前向传播后会调用`criterion`计算损失: ```python outputs = model(samples) # 前向传播 loss_dict = criterion(outputs, targets) # 计算损失字典 weight_dict = criterion.weight_dict # 获取损失权重 total_loss = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys()) # 加权总损失 ``` 3. **打印损失值** 在反向传播前添加打印语句: ```python print(f"分类损失: {loss_dict['loss_ce'].item():.4f}, 边界框损失: {loss_dict['loss_bbox'].item():.4f}") print(f"总损失: {total_loss.item():.4f}") ``` 4. **完整代码片段示例** ```python for epoch in range(epochs): for batch_idx, (images, targets) in enumerate(dataloader): outputs = model(images) loss_dict = criterion(outputs, targets) total_loss = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys()) optimizer.zero_grad() total_loss.backward() optimizer.step() # 打印每个batch的损失 if batch_idx % 10 == 0: print(f"Epoch [{epoch+1}/{epochs}], Batch [{batch_idx}], " f"分类损失: {loss_dict['loss_ce'].item():.4f}, " f"GIoU损失: {loss_dict['loss_giou'].item():.4f}, " f"总损失: {total_loss.item():.4f}") ```
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值