YOLOX代码 loss计算过程 详细注释版

核心代码注释


    # cost [n_gt, n_anchor]
    # pair_wise_ious [n_gt, n_anchor]
    # gt_classes     [num_gt]
    def dynamic_k_matching(self, cost, pair_wise_ious, gt_classes, num_gt, fg_mask):
        # Dynamic K
        # ---------------------------------------------------------------
        matching_matrix = torch.zeros_like(cost, dtype=torch.uint8)

        ious_in_boxes_matrix = pair_wise_ious
        n_candidate_k = min(10, ious_in_boxes_matrix.size(1))
        # [n_gt, <10]
        topk_ious, _ = torch.topk(ious_in_boxes_matrix, n_candidate_k, dim=1)
        # [n_gt] min=1表示至少匹配一个框
        dynamic_ks = torch.clamp(topk_ious.sum(1).int(), min=1)
        dynamic_ks = dynamic_ks.tolist()
        # 对于每个gt框, 放大后, 锚框中点如果被覆盖. 计算iou的时候就累加进来.
        # 一个gt框匹配多少个锚框, 由最后的iou累加结果定.
        for gt_idx in range(num_gt):
            _, pos_idx = torch.topk(
                cost[gt_idx], k=dynamic_ks[gt_idx], largest=False
            )
            matching_matrix[gt_idx][pos_idx] = 1

        del topk_ious, dynamic_ks, pos_idx

        # 一个锚框只能匹配一个gt框, cost最小的优先匹配
        # anchor_matching_gt [n_anchor]
        anchor_matching_gt = matching_matrix.sum(0)
        if (anchor_matching_gt > 1).sum() > 0:
            _, cost_argmin = torch.min(cost[:, anchor_matching_gt > 1], dim=0)
            matching_matrix[:, anchor_matching_gt > 1] *= 0
            matching_matrix[cost_argmin, anchor_matching_gt > 1] = 1
        # [n_anchor] 锚框是否有匹配到gt框
        fg_mask_inboxes = matching_matrix.sum(0) > 0
        # 锚框数量(匹配到gt框的)
        num_fg = fg_mask_inboxes.sum().item()

        # fg_mask:                 [total_num_anchors] 锚框中点被gt框或放大后的gt框覆盖
        # fg_mask更新
        fg_mask[fg_mask.clone()] = fg_mask_inboxes

        # 锚框匹配的gt框的id
        matched_gt_inds = matching_matrix[:, fg_mask_inboxes].argmax(0)
        # gt_classes,                 # [num_gt]       标识类别
        # 锚框匹配的gt框的类别
        gt_matched_classes = gt_classes[matched_gt_inds]

        # 锚框与匹配的gt框的iou
        pred_ious_this_matching = (matching_matrix * pair_wise_ious).sum(0)[
            fg_mask_inboxes
        ]
        return num_fg, gt_matched_classes, pred_ious_this_matching, matched_gt_inds

更多上下文代码注释详见

http://bolun365.github.io/codes/OD-loss(YOLOX)/

Od loss(yolox) - Deeper Learning

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值