模型输出的output
num_query = 900 类别为 7
pred_boxes 为目标框预测 out_bbox (900,7)
pred_logits 经过sigmoid后得到各类别分类概率 out_prob (900,4)
先计算得到label所在所在分类tgt_ids的损失 cost_class
alpha = self.focal_alpha
gamma = 2.0
neg_cost_class = (1 - alpha) * (out_prob ** gamma) * (-(1 - out_prob + 1e-8).log())
pos_cost_class = alpha * ((1 - out_prob) ** gamma) * (-(out_prob + 1e-8).log())
cost_class = pos_cost_class[:, tgt_ids] - neg_cost_class[:, tgt_ids]
检测框的损失cost_bbox 和 cost_giou
cost_bbox = torch.cdist(out_bbox, tgt_bbox, p=1)
cost_giou = generalized_box_iou(box_cxcywh_to_xyxy(out_bbox),
box_cxcywh_to_xyxy(tgt_bbox))
最后将cost_class cost_bbox cost_giou 乘上各自权重(超参数)得到总的cost
生成的cost 有 n = 900 个损失 实际的bbox 为 m (m远小于n)cost (bs,num_query, -1)
构建出 n行 m列 的矩阵,每一列(代表一个bbox)会选择一个cost最小的位置
indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(sizes, -1))]
返回值为坐标 如上图 返回值为 (2,1) (4,2)(6,3)
根据坐标可从n个query中选出m个query与target做损失函数