匈牙利算法详解 - DETR中的应用
算法简介
匈牙利算法(Hungarian Algorithm)是一种解决二分图最大匹配问题的经典算法。在目标检测领域,特别是DETR(DEtection TRansformer)中,匈牙利算法被用于解决预测框和真实框之间的最优匹配问题。
在DETR中的应用
DETR使用匈牙利算法来解决以下问题:
- 一对一匹配:每个预测框只能匹配一个真实框,反之亦然
- 最小化代价:通过最小化总代价来找到最优匹配
- 端到端训练:匹配过程是可微的,可以直接用于深度学习模型训练
算法原理
在DETR中,代价矩阵由三部分组成:
- 类别代价:
cost_class = -pred_logits[:, gt_labels]
这里使用负对数概率作为代价,预测概率越大,代价越小。
数学表达式:
L
class
(
i
,
j
)
=
−
log
(
p
^
i
(
c
j
)
)
\mathcal{L}_{\text{class}}(i,j) = -\log(\hat{p}_i(c_j))
Lclass(i,j)=−log(p^i(cj))
其中:
- p ^ i ( c j ) \hat{p}_i(c_j) p^i(cj) 是第i个预测框对第j个真实框类别的预测概率
- c j c_j cj 是第j个真实框的类别标签
- 掩码代价(对于分割任务):
# 计算二值交叉熵
pos_cost = -(pred_masks.sigmoid() * target_masks)
neg_cost = -((1 - pred_masks.sigmoid()) * (1 - target_masks))
cost_mask = pos_cost.mean() + neg_cost.mean()
数学表达式:
L
mask
(
i
,
j
)
=
1
N
∑
k
=
1
N
[
−
m
i
k
log
(
m
^
i
k
)
−
(
1
−
m
i
k
)
log
(
1
−
m
^
i
k
)
]
\mathcal{L}_{\text{mask}}(i,j) = \frac{1}{N} \sum_{k=1}^N [-m_{ik}\log(\hat{m}_{ik}) - (1-m_{ik})\log(1-\hat{m}_{ik})]
Lmask(i,j)=N1k=1∑N[−miklog(m^ik)−(1−mik)log(1−m^ik)]
其中:
- m ^ i k \hat{m}_{ik} m^ik 是第i个预测掩码在像素k处的sigmoid输出
- m i k m_{ik} mik 是第i个真实掩码在像素k处的二值标签
- N N N 是掩码的像素总数
- Dice代价(对于分割任务):
numerator = 2 * (pred_masks.sigmoid() * target_masks).sum()
denominator = pred_masks.sigmoid().sum() + target_masks.sum()
cost_dice = 1 - (numerator + 1) / (denominator + 1)
数学表达式:
L
dice
(
i
,
j
)
=
1
−
2
∣
M
^
i
∩
M
j
∣
+
1
∣
M
^
i
∣
+
∣
M
j
∣
+
1
\mathcal{L}_{\text{dice}}(i,j) = 1 - \frac{2|\hat{M}_i \cap M_j| + 1}{|\hat{M}_i| + |M_j| + 1}
Ldice(i,j)=1−∣M^i∣+∣Mj∣+12∣M^i∩Mj∣+1
其中:
- M ^ i \hat{M}_i M^i 是第i个预测掩码的sigmoid输出
- M j M_j Mj 是第j个真实掩码
- ∣ ⋅ ∣ |\cdot| ∣⋅∣ 表示掩码中像素值的总和
- 分子中的交集操作通过逐像素相乘实现
- 分母加1是为了数值稳定性
最终的代价矩阵由这三部分加权组合:
C
(
i
,
j
)
=
λ
1
L
class
(
i
,
j
)
+
λ
2
L
mask
(
i
,
j
)
+
λ
3
L
dice
(
i
,
j
)
\mathcal{C}(i,j) = \lambda_1\mathcal{L}_{\text{class}}(i,j) + \lambda_2\mathcal{L}_{\text{mask}}(i,j) + \lambda_3\mathcal{L}_{\text{dice}}(i,j)
C(i,j)=λ1Lclass(i,j)+λ2Lmask(i,j)+λ3Ldice(i,j)
其中
λ
1
\lambda_1
λ1,
λ
2
\lambda_2
λ2,
λ
3
\lambda_3
λ3 是各部分代价的权重系数。
最优匹配求解
使用scipy的linear_sum_assignment实现匈牙利算法:
from scipy.optimize import linear_sum_assignment
row_ind, col_ind = linear_sum_assignment(cost_matrix)
算法工作原理
在DETR中,代价矩阵的维度关系如下:
- n:预测框的数量(通常固定为查询的数量,如300)
- m:真实框的数量(每张图片中的目标数量,可变)
具体来说:
-
预测端(n):
- DETR使用固定数量的object queries(如300个)
- 每个query预测一个可能的目标
- 这些预测包含类别、边界框和掩码(如果有)
-
真实标签端(m):
- 每张图片中实际目标的数量
- 通常远小于预测数量(如5-10个目标)
- 包含真实的类别、边界框和掩码标注
-
代价矩阵 C:
- 维度为 n×m(如300×5)
- C[i,j]表示第i个预测匹配第j个真实目标的代价
- 最终只会匹配m个预测,其余预测应该预测"无目标"类别
这种设计的优势:
- 固定数量的查询简化了网络结构
- 过量的预测提供了充分的候选集
- 通过匈牙利算法自动选择最优的m个预测
- 剩余的预测通过分类损失学习预测"无目标"
代码实现
以下是核心实现及分析:
class HungarianAssigner:
def __init__(self, match_cost_class=1, match_cost_mask=1, match_cost_dice=1):
"""初始化匈牙利分配器
Args:
match_cost_class: 类别代价权重
match_cost_mask: 掩码代价权重
match_cost_dice: Dice代价权重
"""
self.match_cost_class = match_cost_class
self.match_cost_mask = match_cost_mask
self.match_cost_dice = match_cost_dice
def compute_mask_cost(self, pred_masks, gt_masks):
"""计算掩码代价矩阵
Args:
pred_masks: [B, Q, H, W] 预测掩码
gt_masks: [B, G, H, W] 真实掩码
Returns:
mask_cost: [B, Q, G] 掩码代价矩阵
"""
B, Q = pred_masks.shape[:2]
_, G = gt_masks.shape[:2]
# 展开并计算概率
pred_masks = pred_masks.reshape(B, Q, -1).sigmoid()
gt_masks = gt_masks.reshape(B, G, -1)
# 广播计算
pred_masks = pred_masks.unsqueeze(2) # [B, Q, 1, H*W]
gt_masks = gt_masks.unsqueeze(1) # [B, 1, G, H*W]
# 计算L1距离
mask_cost = (pred_masks - gt_masks).abs().mean(dim=-1)
return mask_cost
def compute_dice_cost(self, pred_masks, gt_masks):
"""计算Dice代价矩阵
Args:
pred_masks: [B, Q, H, W] 预测掩码
gt_masks: [B, G, H, W] 真实掩码
Returns:
dice_cost: [B, Q, G] Dice代价矩阵
"""
B, Q = pred_masks.shape[:2]
_, G = gt_masks.shape[:2]
# 展开并计算概率
pred_masks = pred_masks.reshape(B, Q, -1).sigmoid()
gt_masks = gt_masks.reshape(B, G, -1)
# 广播计算
pred_masks = pred_masks.unsqueeze(2)
gt_masks = gt_masks.unsqueeze(1)
# 计算Dice系数
intersection = (pred_masks * gt_masks).sum(dim=-1)
union = pred_masks.sum(dim=-1) + gt_masks.sum(dim=-1)
dice = (2 * intersection + 1e-6) / (union + 1e-6)
return 1 - dice
def __call__(self, pred_logits, pred_masks, gt_labels, gt_masks):
"""执行匹配
Args:
pred_logits: [B, Q, C+1] 预测类别logits
pred_masks: [B, Q, H, W] 预测掩码
gt_labels: [B, G] 真实标签
gt_masks: [B, G, H, W] 真实掩码
Returns:
indices: List[Tuple] 匹配索引对
"""
B = pred_logits.shape[0]
indices = []
for b in range(B):
# 计算类别代价
cost_class = -pred_logits[b, :, gt_labels[b]]
# 计算掩码和Dice代价
cost_mask = self.compute_mask_cost(
pred_masks[b:b+1], gt_masks[b:b+1])[0]
cost_dice = self.compute_dice_cost(
pred_masks[b:b+1], gt_masks[b:b+1])[0]
# 组合代价矩阵
cost_matrix = (
self.match_cost_class * cost_class +
self.match_cost_mask * cost_mask +
self.match_cost_dice * cost_dice
)
# 执行匈牙利算法
pred_ids, gt_ids = linear_sum_assignment(
cost_matrix.detach().cpu().numpy())
indices.append((pred_ids, gt_ids))
return indices
测试结果分析
目标检测场景测试
测试数据规模:
- Batch size: 2
- 预测掩码: [2, 4, 32, 32] (每个batch 4个预测)
- 真实掩码: [2, 3, 32, 32] (每个batch 3个目标)
Batch 0 匹配结果
预测框 | 真实框 | 分类代价 | 掩码代价 | Dice代价 | 总代价 |
---|---|---|---|---|---|
0 | 0 | -0.0421 | 0.4982 | 0.8635 | 1.3196 |
1 | 1 | -0.6307 | 0.4982 | 0.8635 | 0.7310 |
3 | 2 | -1.9132 | 0.5000 | 0.9143 | -0.4989 |
Batch 1 匹配结果
预测框 | 真实框 | 分类代价 | 掩码代价 | Dice代价 | 总代价 |
---|---|---|---|---|---|
0 | 1 | -1.4476 | 0.5144 | 0.8917 | -0.0415 |
2 | 0 | 0.1767 | 0.5144 | 0.8917 | 1.5828 |
3 | 2 | -1.3510 | 0.5000 | 0.9143 | 0.0633 |
分析:
- 分类代价:
- 范围从-1.9132到0.1767,负值表示较好的类别预测
- 最佳匹配通常具有较低的分类代价
- 掩码代价:
- 稳定在0.4982到0.5144之间
- 表示预测掩码和真实掩码有约50%的重叠
- Dice代价:
- 范围在0.8635到0.9143之间
- 较高的值表示掩码匹配还有改进空间
分割场景测试
测试数据规模:
- Batch size: 2
- 预测掩码: [2, 3, 32, 32] (每个batch 3个预测)
- 真实掩码: [2, 2, 32, 32] (每个batch 2个目标)
Batch 0 匹配结果
预测框 | 真实框 | 分类代价 | 掩码代价 | Dice代价 | 总代价 |
---|---|---|---|---|---|
0 | 0 | -1.5893 | 0.4898 | 0.7704 | -0.3291 |
1 | 1 | -1.5812 | 0.4763 | 0.6464 | -0.4584 |
Batch 1 匹配结果
预测框 | 真实框 | 分类代价 | 掩码代价 | Dice代价 | 总代价 |
---|---|---|---|---|---|
0 | 1 | -0.9829 | 0.4777 | 0.6654 | 0.1601 |
1 | 0 | 0.4253 | 0.4966 | 0.7584 | 1.6803 |
分析:
- 分类代价:
- 范围从-1.5893到0.4253
- Batch 0的分类效果明显优于Batch 1
- 掩码代价:
- 范围在0.4763到0.4966之间
- 比目标检测场景略好,说明圆形掩码的匹配更准确
- Dice代价:
- 范围在0.6464到0.7704之间
- 明显低于目标检测场景,表明分割任务的掩码匹配质量更好
总结
匈牙利算法在DETR中的应用展示了经典算法在现代深度学习中的重要性。通过合理的代价设计和高效的实现,它能够有效解决目标检测和实例分割中的匹配问题。上述代码实现和结果表明,该方法能够准确地找到预测框和真实框之间的最优匹配,为模型训练提供可靠的监督信号。