核心代码注释
# cost [n_gt, n_anchor]
# pair_wise_ious [n_gt, n_anchor]
# gt_classes [num_gt]
def dynamic_k_matching(self, cost, pair_wise_ious, gt_classes, num_gt, fg_mask):
    """Assign candidate anchors to ground-truth boxes with a per-box dynamic k.

    For every ground-truth box, k is the (truncated) sum of its top-10 IoUs
    with the candidate anchors, clamped to at least 1. The k anchors with the
    lowest cost are matched to that box. An anchor claimed by several boxes is
    kept only by the box for which its cost is smallest.

    Args:
        cost: Tensor [n_gt, n_anchor] of assignment costs (lower is better).
        pair_wise_ious: Tensor [n_gt, n_anchor] of IoUs between ground-truth
            boxes and candidate anchors.
        gt_classes: Tensor [num_gt] with the class of each ground-truth box.
        num_gt: Number of ground-truth boxes (rows of ``cost``).
        fg_mask: Boolean tensor [total_num_anchors] marking the candidate
            anchors; its number of True entries must equal n_anchor.
            It is updated IN PLACE so that only matched anchors stay True.

    Returns:
        Tuple ``(num_fg, gt_matched_classes, pred_ious_this_matching,
        matched_gt_inds)``: the number of matched anchors, and for each
        matched anchor the class of its box, its IoU with that box, and the
        index of that box.
    """
    # Degenerate inputs: with no ground-truth boxes or no candidate anchors
    # nothing can be matched. Handle this explicitly, since topk/argmax below
    # cannot operate on an empty dimension.
    if num_gt == 0 or cost.size(1) == 0:
        fg_mask.fill_(False)
        matched_gt_inds = torch.zeros(0, dtype=torch.long, device=cost.device)
        gt_matched_classes = gt_classes[:0]
        pred_ious_this_matching = pair_wise_ious.new_zeros(0)
        return 0, gt_matched_classes, pred_ious_this_matching, matched_gt_inds

    # matching_matrix[g, a] == 1 means anchor a is assigned to box g.
    matching_matrix = torch.zeros_like(cost, dtype=torch.uint8)

    # Dynamic k: sum of the top-k IoUs per box, [n_gt, <=10] -> [n_gt].
    n_candidate_k = min(10, pair_wise_ious.size(1))
    topk_ious, _ = torch.topk(pair_wise_ious, n_candidate_k, dim=1)
    # min=1 guarantees that every box is matched to at least one anchor.
    dynamic_ks = torch.clamp(topk_ious.sum(1).int(), min=1).tolist()

    # Each box takes its dynamic_k lowest-cost anchors.
    for gt_idx in range(num_gt):
        _, pos_idx = torch.topk(cost[gt_idx], k=dynamic_ks[gt_idx], largest=False)
        matching_matrix[gt_idx, pos_idx] = 1

    # An anchor may belong to one box only: on conflict, the box with the
    # smallest cost for that anchor wins.
    anchor_matching_gt = matching_matrix.sum(0)  # [n_anchor]
    multiple_match_mask = anchor_matching_gt > 1
    if multiple_match_mask.any():
        _, cost_argmin = torch.min(cost[:, multiple_match_mask], dim=0)
        matching_matrix[:, multiple_match_mask] = 0
        matching_matrix[cost_argmin, multiple_match_mask] = 1

    # [n_anchor] whether each candidate anchor ended up matched to a box.
    fg_mask_inboxes = matching_matrix.sum(0) > 0
    num_fg = fg_mask_inboxes.sum().item()

    # Narrow fg_mask in place: among the previously-True positions, keep only
    # the anchors that were actually matched.
    fg_mask[fg_mask.clone()] = fg_mask_inboxes

    # Index and class of the box matched to each foreground anchor.
    matched_gt_inds = matching_matrix[:, fg_mask_inboxes].argmax(0)
    gt_matched_classes = gt_classes[matched_gt_inds]

    # IoU between each foreground anchor and its matched box.
    pred_ious_this_matching = (matching_matrix * pair_wise_ious).sum(0)[
        fg_mask_inboxes
    ]
    return num_fg, gt_matched_classes, pred_ious_this_matching, matched_gt_inds
更多上下文代码注释详见
http://bolun365.github.io/codes/OD-loss(YOLOX)/