DINO中的匈牙利匹配算法

模型输出的output

 num_query = 900 类别为 7

pred_boxes 为目标框预测   out_bbox (900,7)

pred_logits 经过sigmoid后得到各类别分类概率  out_prob (900,4)

先计算得到label所在所在分类tgt_ids的损失 cost_class 

alpha = self.focal_alpha
gamma = 2.0
neg_cost_class = (1 - alpha) * (out_prob ** gamma) * (-(1 - out_prob + 1e-8).log())
pos_cost_class = alpha * ((1 - out_prob) ** gamma) * (-(out_prob + 1e-8).log())
cost_class = pos_cost_class[:, tgt_ids] - neg_cost_class[:, tgt_ids]

检测框的损失cost_bbox cost_giou

cost_bbox = torch.cdist(out_bbox, tgt_bbox, p=1)

cost_giou = generalized_box_iou(box_cxcywh_to_xyxy(out_bbox),

box_cxcywh_to_xyxy(tgt_bbox))

最后将cost_class   cost_bbox  cost_giou 乘上各自权重(超参数)得到总的cost
生成的cost 有 n = 900 个损失 实际的bbox 为 m (m远小于n)cost (bs,num_query, -1)

构建出 n行 m列 的矩阵,每一列(代表一个bbox)会选择一个cost最小的位置

 

 indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(sizes, -1))]

返回值为坐标 如上图 返回值为 (2,1) (4,2)(6,3)

根据坐标可从n个query中选出m个query与target做损失函数

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值