还看不懂 DETR 的匈牙利损失函数?4个公式教你理解

看到 DETR 的损失函数的时候,你是否有下面的疑问:

  • 公式中的 σ ∈ S N \sigma \in \mathfrak{S}_N σSN 是什么意思?
  • 公式中的 y ^ σ ( i ) \hat{y}_{\sigma(i)} y^σ(i) 的下标 σ ( i ) \sigma(i) σ(i) 又有什么含义?
  • DETR 的损失函数计算的完整流程又是怎么样的?
  • 为什么计算 box 损失的时候为什么要加上 GIOU 损失

等等问题,都可以在下面的文章中得到解答。

概述

在 DETR 中,进行梯度更新可以分成 2 步:

  1. 使用匈牙利匹配算法,根据优化函数求解集合 y y y y ^ \hat{y} y^ 的最佳匹配:集合 y ^ \hat{y} y^ 的排列 σ ^ \hat{\sigma} σ^ σ ^ = arg min ⁡ σ ∈ S N ∑ i N L match ( y i , y ^ σ ( i ) ) \hat{\sigma}=\argmin_{\sigma \in \mathfrak{S}_N }\sum_i^{N}\mathcal{L}_{\text{match}}(y_i, \hat{y}_{\sigma(i)}) σ^=σSNargminiNLmatch(yi,y^σ(i)) L match ( y i , y ^ σ ( i ) ) = − 1 { c i ≠ ∅ } p ^ σ ( i ) ( c i ) + 1 { c i ≠ ∅ } L box ( b i , b ^ σ ( i ) ) \mathcal{L}_{\text{match}}(y_i, \hat{y}_{\sigma(i)}) = -\mathbb{1}_{\{c_i\ne \varnothing\}}\hat{p}_{\sigma(i)}(c_i) + \mathbb{1}_{\{c_i\ne \varnothing\}}\mathcal{L}_{\text{box}}(b_i, \hat{b}_{\sigma(i)} ) Lmatch(yi,y^σ(i))=1{ci=}p^σ(i)(ci)+1{ci=}Lbox(bi,b^σ(i))
  2. 根据集合 y ^ \hat{y} y^ 最佳排列 σ ^ \hat{\sigma} σ^ 带入损失函数中求解损失,并进行梯度更新。 L Hungarian ( y , y ^ ) = ∑ i = 1 N [ − log ⁡ p ^ σ ^ ( i ) ( c i ) + 1 { c i ≠ ∅ } L box ( b i , b ^ σ ^ ( i ) ) ] \mathcal{L}_{\text{Hungarian}}(y, \hat{y}) = \sum_{i=1}^N \left[ -\log\hat{p}_{\hat{\sigma}(i)}(c_i) + \mathbb{1}_{\{c_i \ne \varnothing\}}\mathcal{L}_{\text{box}}(b_i, \hat{b}_{\hat{\sigma}(i)})\right] LHungarian(y,y^)=i=1N[logp^σ^(i)(ci)+1{ci=}Lbox(bi,b^σ^(i))] L box ( b i , b ^ σ ^ ( i ) ) = λ giou L giou ( b i , b ^ σ ^ ( i ) ) + λ L1 ∣ ∣ b i − b ^ σ ^ ( i ) ∣ ∣ 1 \mathcal{L}_{\text{box}}(b_i,\hat{b}_{\hat{\sigma}(i)}) =\lambda_{\text{giou}}\mathcal{L}_{\text{giou}}(b_i,\hat{b}_{\hat{\sigma}(i)}) + \lambda_{\text{L1}}||b_i - \hat{b}_{\hat{\sigma}(i)}||_1 Lbox(bi,b^σ^(i))=λgiouLgiou(bi,b^σ^(i))+λL1∣∣bib^σ^(i)1

可以看出来,其实想要理解 DETR 的损失函数是怎么计算的,只要理解上面的 4 个公式就行了。

第一步:求最佳 σ ^ \hat{\sigma} σ^

σ ^ = arg min ⁡ σ ∈ S N ∑ i N L match ( y i , y ^ σ ( i ) ) \hat{\sigma}=\argmin_{\sigma \in \mathfrak{S}_N }\sum_i^{N}\mathcal{L}_{\text{match}}(y_i, \hat{y}_{\sigma(i)}) σ^=σSNargminiNLmatch(yi,y^σ(i))(还不算是损失函数,只是通过匈牙利匹配算法求解最优排列的一个优化目标函数)

  • y = { y i } i = 1 N y=\{y_i\}_{i=1}^N y={yi}i=1N:表示 N N N个 ground truth 的集合,其中 y i y_i yi是第 i i i 个 ground truth,当然实际中,集合 y y y 中的 ground truth 数量是远小于 N N N的,为了让 y y y y ^ \hat{y} y^ 两个集合大小一致,在集合 y y y 中会使用 ∅ \varnothing (no object)来对集合进行填充。
  • y ^ = { y ^ i } i = 1 N \hat{y}=\{\hat{y}_i\}_{i=1}^N y^={y^i}i=1N:表示 N N N个 预测的集合,其中 y ^ i \hat{y}_i y^i是第 i i i 个预测。
  • σ \sigma σ:是一种预测值 y ^ \hat{y} y^ 的排列方式,我们知道集合 y y y 与集合 y ^ \hat{y} y^ 要一一匹配,然后进行排列,我们把 y y y 的排列顺序固定,就只需要调整 y ^ \hat{y} y^ 的排列顺序就可以了,而 σ \sigma σ就是表示的集合 y ^ \hat{y} y^ 的某种排列方式, y ^ σ ( i ) \hat{y}_{\sigma(i)} y^σ(i) 也只是表示,在 σ \sigma σ这种排列中,第 i i i 个预测值。
  • S N \mathfrak{S}_N SN:是排列 σ \sigma σ 的集合,也是一种对称群? arg min ⁡ σ ∈ S N \argmin_{\sigma \in \mathfrak{S}_N } argminσSN表示在 S N \mathfrak{S}_N SN内存在一种集合 y ^ \hat{y} y^ 的排列 σ \sigma σ,可以使得匈牙利匹配的 cost 最低。
  • L match ( y i , y ^ σ ( i ) ) \mathcal{L}_{\text{match}}(y_i, \hat{y}_{\sigma(i)}) Lmatch(yi,y^σ(i)):是 pair-wise matching cost ,一般是使用匈牙利算法进行计算的。

L match ( y i , y ^ σ ( i ) ) = − 1 { c i ≠ ∅ } p ^ σ ( i ) ( c i ) + 1 { c i ≠ ∅ } L box ( b i , b ^ σ ( i ) ) \mathcal{L}_{\text{match}}(y_i, \hat{y}_{\sigma(i)}) = -\mathbb{1}_{\{c_i\ne \varnothing\}}\hat{p}_{\sigma(i)}(c_i) + \mathbb{1}_{\{c_i\ne \varnothing\}}\mathcal{L}_{\text{box}}(b_i, \hat{b}_{\sigma(i)} ) Lmatch(yi,y^σ(i))=1{ci=}p^σ(i)(ci)+1{ci=}Lbox(bi,b^σ(i))每个 ground truth y i y_i yi都是由两部分信息组成的,类别 + 位置,也可以写成 y i = ( c i , b i ) y_i=(c_i, b_i) yi=(ci,bi),其中:

  • c i c_i ci:表示类别信息(也有可能是空 ∅ \varnothing
  • b i b_i bi:表示位置信息,是一个归一化(值都小于 1 )的向量,有 4 个值,分别表示 box 中心点的坐标和宽高。

对于预测值 y ^ σ ( i ) \hat{y}_{\sigma(i)} y^σ(i),我们将类别和位置信息定义为 y ^ σ ( i ) = ( p ^ σ ( i ) ( c i ) , b ^ σ ( i ) ) \hat{y}_{\sigma(i)}=(\hat{p}_{\sigma(i)}(c_i), \hat{b}_{\sigma(i)}) y^σ(i)=(p^σ(i)(ci),b^σ(i))

  • p ^ σ ( i ) ( c i ) \hat{p}_{\sigma(i)}(c_i) p^σ(i)(ci):我们已经知道了 ground truch 的类别信息 c i c_i ci,这个概率值是通过模型的分类器计算得出的,反映了模型对于该预测值属于类别 c i c_i ci 的确信度。

第二步:求损失

L Hungarian ( y , y ^ ) = ∑ i = 1 N [ − log ⁡ p ^ σ ^ ( i ) ( c i ) + 1 { c i ≠ ∅ } L box ( b i , b ^ σ ^ ( i ) ) ] \mathcal{L}_{\text{Hungarian}}(y, \hat{y}) = \sum_{i=1}^N \left[ -\log\hat{p}_{\hat{\sigma}(i)}(c_i) + \mathbb{1}_{\{c_i \ne \varnothing\}}\mathcal{L}_{\text{box}}(b_i, \hat{b}_{\hat{\sigma}(i)})\right] LHungarian(y,y^)=i=1N[logp^σ^(i)(ci)+1{ci=}Lbox(bi,b^σ^(i))]论文这里还将 b ^ σ ^ ( i ) \hat{b}_{\hat{\sigma}(i)} b^σ^(i)打错成了 b ^ σ ^ ( i ) \hat{b}_{\hat{\sigma}}(i) b^σ^(i)

  • σ ^ \hat{\sigma} σ^是最优的排列,也就是使得整体 cost 最小的 y ^ \hat{y} y^ 排列。
  • log ⁡ \log log:这里存在一个问题,为什么上面的 − 1 { c i ≠ ∅ } p ^ σ ( i ) ( c i ) -\mathbb{1}_{\{c_i\ne \varnothing\}}\hat{p}_{\sigma(i)}(c_i) 1{ci=}p^σ(i)(ci) 在这里就变成了 − log ⁡ p ^ σ ^ ( i ) ( c i ) -\log\hat{p}_{\hat{\sigma}(i)}(c_i) logp^σ^(i)(ci)。一个 no object ∅ \varnothing y y y) 与预测值( y ^ \hat{y} y^)的 L match ( y i , y ^ σ ( i ) ) \mathcal{L}_{\text{match}}(y_i, \hat{y}_{\sigma(i)}) Lmatch(yi,y^σ(i))匹配代价实际上并不取决于预测值,因为 c i = ∅ c_i = \varnothing ci=的时候,因为指示函数的关系, L match ( y i , y ^ σ ( i ) ) = 0 \mathcal{L}_{\text{match}}(y_i, \hat{y}_{\sigma(i)}) = 0 Lmatch(yi,y^σ(i))=0,也就是一个常数。在计算匹配代价(cost)的类别代价(cost)的时候,我们使用概率而不是对数概率,因为实际效果更好。
  • 上面的公式是为了求解最优的排列 σ ^ \hat{\sigma} σ^,而这里根据最优的排列 σ ^ \hat{\sigma} σ^ 来求解损失。一般来说,为了解决类间不平衡问题,会在 c i = ∅ c_i = \varnothing ci= 对数概率项前乘以 1 / 10 1/10 1/10来降低权重。

L box ( b i , b ^ σ ^ ( i ) ) = λ giou L giou ( b i , b ^ σ ^ ( i ) ) + λ L1 ∣ ∣ b i − b ^ σ ^ ( i ) ∣ ∣ 1 \mathcal{L}_{\text{box}}(b_i,\hat{b}_{\hat{\sigma}(i)}) =\lambda_{\text{giou}}\mathcal{L}_{\text{giou}}(b_i,\hat{b}_{\hat{\sigma}(i)}) + \lambda_{\text{L1}}||b_i - \hat{b}_{\hat{\sigma}(i)}||_1 Lbox(bi,b^σ^(i))=λgiouLgiou(bi,b^σ^(i))+λL1∣∣bib^σ^(i)1

  • 直接使用 L1 损失:因为 L1 损失是计算绝对值,但是在目标检测中,大目标和小目标,即便是有着相同的相对误差(relative error),其绝对误差,也就是 L1 损失值都会有很大差异,对尺度的支持比较差。
  • 所以为了缓解 L1 损失的尺度不变性比较差的问题,我们引入了 GIOU(关于 GIOU 的部分,我后续有时间再进行补充吧)

后言

还有什么疑问或者问题可以在评论区评论,我会尽可能的解答并更新文章。

  • 23
    点赞
  • 20
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
DETR(Detection Transformer)是一种基于Transformer的目标检测模型,其匈牙利算法是用于实现目标与预测框之间的匹配和关联的。下面是匈牙利算法的原理解析。 匈牙利算法是一种经典的图论算法,用于解决最大权(或最小权)匹配的问题。在目标检测,我们需要将预测框与真实目标进行匹配,以确定哪些预测框与目标匹配。因此,可以使用匈牙利算法来解决这个问题。 匈牙利算法的核心思想是在二分图寻找最大匹配。二分图是一种特殊的图,其的节点被分为两个不相交的部分,每个部分的节点之间没有边相连。在目标检测,我们可以将预测框和真实目标分别看作二分图的两个部分。 匈牙利算法通过寻找增广路来实现最大匹配。增广路是指一条从未匹配的节点开始,交替经过匹配边和非匹配边,最终到达另一个未匹配的节点的路径。通过寻找增广路,可以将匹配数量不断增加,直到无法寻找到新的增广路为止。 具体来说,匈牙利算法可以分为以下几个步骤: 1. 初始化:将所有预测框和真实目标都设置为未匹配状态。 2. 寻找增广路:从一个未匹配的预测框开始,依次寻找增广路,将预测框与真实目标匹配。 3. 更新匹配:将所有找到的增广路的预测框和真实目标进行匹配,并将其他未匹配的预测框和真实目标保持不变。 4. 判断是否结束:如果所有预测框都已经匹配算法结束。否则,返回第2步,继续寻找增广路。 通过这样的方式,匈牙利算法可以找到最大匹配,并将预测框与真实目标进行匹配。在DETR模型匈牙利算法被用于实现目标和预测框之间的匹配,以便进行目标检测。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值