Yolox标签匹配算法SimOTA原理及代码解释

SimOTA简介

        1.确定正样本候选区域(使用中心先验)
        2.计算每个样本对每个真实框的Reg + Csloss(Loss aware)
        3.使用每个真实框的预测样本确定它需要分配到的正样本数(Dynamic k)
            获取与当前真实框的ciou前10的样本;
            将这Top10样本的ciou求和取整,就为当前真实框的dynamic k,dynamic k最小保证为10这个数字并不敏感,在5-15之间几乎都没有影响
        4.为每个真实框取loss最小的前dynamick个样本作为正样本
        5.去掉同一个样本被分配到多个真实框的正样本的情况(全局信息)

代码解释

import torch
import torch.nn.functional as F

/*'''
# 输出把预测框,预测框的得分,类别进行拼接
output = torch.cat([reg_output, obj_output, cls_output], 1)

# 对输出进行解码 size 为 [1, w*h, 5+类别数],同时生成与之对应的网格 size 为 [1, w*h, 2]
output, grid = get_output_and_grid(
    output, k, stride_this_level, xin[0].type()
)
x_shifts.append(grid[:, :, 0])
y_shifts.append(grid[:, :, 1])

# 每一个输出结果的下采样步长
expanded_strides.append(
    torch.zeros(1, grid.shape[1])
    .fill_(stride_this_level)
    .type_as(xin[0])
)
'''*/
# 对输出进行解码 size 为 [1, w*h, 5+类别数],同时生成与之对应的网格 size 为 [1, w*h, 2]
def get_output_and_grid(output, k, stride, dtype):

    batch_size = output.shape[0]
    n_ch = 5 + 80 # COCO数据集类别为805为预测框坐标+预测框置信度
    hsize, wsize = output.shape[-2:]
    if grid.shape[2:4] != output.shape[2:4]:
        # yv, xv 的 size 为 [hsize, wsize]
        yv, xv = torch.meshgrid([torch.arange(hsize), torch.arange(wsize)])
        # xv, yv 与前面的位置相反,猜测应该是坐标轴上的位置和矩阵位置不同
        grid = torch.stack((xv, yv), 2).view(1, 1, hsize, wsize, 2).type(dtype)

    output = output.permute(0, 3, 1, 2).reshape(
        batch_size, hsize * wsize, -1
    )
    grid = grid.view(1, -1, 2)
    # 对输出进行解码
    output[..., :2] = (output[..., :2] + grid) * stride
    output[..., 2:4] = torch.exp(output[..., 2:4]) * stride
    return output, grid

@torch.no_grad()
def get_assignments(
    self,
    batch_idx, # simota 对输出的每张图片分别进行标签匹配,所以会传入 batch 的索引
    num_gt, # 每张图片目标的个数
    total_num_anchors, # 每张图片预测框的个数
    gt_bboxes_per_image, # 每张图片的真实框
    gt_classes, # 每张图片真实框对应的类别
    bboxes_preds_per_image, # 每张图片的预测框
    expanded_strides, # 预测框对应的下采样步长
    x_shifts, # 网格的横坐标(下采样后的网格)
    y_shifts, # 网格的纵坐标(用于正样本的粗略筛选,确定正样本候选区域)
    cls_preds, # 输出的全部类别
    obj_preds, # 输出的全部预测框的置信度
):
    # 确定正样本候选区域
    fg_mask, is_in_boxes_and_center = get_in_boxes_info(
        gt_bboxes_per_image,
        expanded_strides,
        x_shifts,
        y_shifts,
        total_num_anchors,
        num_gt,
    )

    bboxes_preds_per_image = bboxes_preds_per_image[fg_mask]
    cls_preds_ = cls_preds[batch_idx][fg_mask]
    obj_preds_ = obj_preds[batch_idx][fg_mask]
    # 正样本的数量
    num_in_boxes_anchor = bboxes_preds_per_image.shape[0]
    # 计算真实框和正样本的ciou值
    # size 为[num_gt, num_in_boxes_anchor]
    pair_wise_ious = bboxes_iou(gt_bboxes_per_image, bboxes_preds_per_image, False)

    # [num_gt, num_in_boxes_anchor, num_classes]
    # 真实框的类别 one hot 编码
    gt_cls_per_image = (
        F.one_hot(gt_classes.to(torch.int64), self.num_classes)
        .float()
        .unsqueeze(1)
        .repeat(1, num_in_boxes_anchor, 1)
    )

    # 进一步处理ciou损失
    pair_wise_ious_loss = -torch.log(pair_wise_ious + 1e-8)

    # 计算类别损失
    with torch.cuda.amp.autocast(enabled=False):
        cls_preds_ = (
            cls_preds_.float().unsqueeze(0).repeat(num_gt, 1, 1).sigmoid_()
            * obj_preds_.float().unsqueeze(0).repeat(num_gt, 1, 1).sigmoid_()
        )
        pair_wise_cls_loss = F.binary_cross_entropy(
            cls_preds_.sqrt_(), gt_cls_per_image, reduction="none"
        ).sum(-1)
    del cls_preds_
    
    # 匹配损失计算 
    cost = (
        pair_wise_cls_loss
        + 3.0 * pair_wise_ious_loss
        + 100000.0 * (~is_in_boxes_and_center)
    )

    (
        num_fg,
        gt_matched_classes,
        pred_ious_this_matching,
        matched_gt_inds,
    ) = dynamic_k_matching(cost, pair_wise_ious, gt_classes, num_gt, fg_mask)
    del pair_wise_cls_loss, cost, pair_wise_ious, pair_wise_ious_loss

    return (
        gt_matched_classes,
        fg_mask,
        pred_ious_this_matching,
        matched_gt_inds,
        num_fg,
    )

def get_in_boxes_info(
    gt_bboxes_per_image,
    expanded_strides,
    x_shifts,
    y_shifts,
    total_num_anchors,
    num_gt,
):
    # 每张图片对应的下采样步长
    expanded_strides_per_image = expanded_strides[0]
    # 坐标扩展到原图大小
    x_shifts_per_image = x_shifts[0] * expanded_strides_per_image
    y_shifts_per_image = y_shifts[0] * expanded_strides_per_image

    # 原始坐标是网格的左上坐标,将其移动到中点 [n_anchor] -> [n_gt, n_anchor]
    # 锚点的原始坐标是网格的左上坐标,将其移动到中点
    x_centers_per_image = (
        (x_shifts_per_image + 0.5 * expanded_strides_per_image)
        .unsqueeze(0)
        .repeat(num_gt, 1)
    )  
    y_centers_per_image = (
        (y_shifts_per_image + 0.5 * expanded_strides_per_image)
        .unsqueeze(0)
        .repeat(num_gt, 1)
    )
    
    # 真实框的4个坐标为,中心点(x,y)和宽高(w,h), 将中心点分别向左上和右下移动0.5倍的w或h,形成一个框。
    # size 为 [真实框个数, 预测框个数]
    gt_bboxes_per_image_l = (
        (gt_bboxes_per_image[:, 0] - 0.5 * gt_bboxes_per_image[:, 2])
        .unsqueeze(1)
        .repeat(1, total_num_anchors)
    )
    gt_bboxes_per_image_r = (
        (gt_bboxes_per_image[:, 0] + 0.5 * gt_bboxes_per_image[:, 2])
        .unsqueeze(1)
        .repeat(1, total_num_anchors)
    )
    gt_bboxes_per_image_t = (
        (gt_bboxes_per_image[:, 1] - 0.5 * gt_bboxes_per_image[:, 3])
        .unsqueeze(1)
        .repeat(1, total_num_anchors)
    )
    gt_bboxes_per_image_b = (
        (gt_bboxes_per_image[:, 1] + 0.5 * gt_bboxes_per_image[:, 3])
        .unsqueeze(1)
        .repeat(1, total_num_anchors)
    )
    # 判断锚点是否在上面真实框中心坐标移动形成的框内 
    b_l = x_centers_per_image - gt_bboxes_per_image_l
    b_r = gt_bboxes_per_image_r - x_centers_per_image
    b_t = y_centers_per_image - gt_bboxes_per_image_t
    b_b = gt_bboxes_per_image_b - y_centers_per_image
    bbox_deltas = torch.stack([b_l, b_t, b_r, b_b], 2)

    # 判断锚点是否为候选框
    is_in_boxes = bbox_deltas.min(dim=-1).values > 0.0
    # 每个真实框是否都有正样本
    is_in_boxes_all = is_in_boxes.sum(dim=0) > 0
    # in fixed center

    center_radius = 2.5

    # 将真实框中心点分别向左上和右下移动0.5倍的下采样步长,形成一个框。操作和上面一样不在解释
    gt_bboxes_per_image_l = (gt_bboxes_per_image[:, 0]).unsqueeze(1).repeat(
        1, total_num_anchors
    ) - center_radius * expanded_strides_per_image.unsqueeze(0)
    gt_bboxes_per_image_r = (gt_bboxes_per_image[:, 0]).unsqueeze(1).repeat(
        1, total_num_anchors
    ) + center_radius * expanded_strides_per_image.unsqueeze(0)
    gt_bboxes_per_image_t = (gt_bboxes_per_image[:, 1]).unsqueeze(1).repeat(
        1, total_num_anchors
    ) - center_radius * expanded_strides_per_image.unsqueeze(0)
    gt_bboxes_per_image_b = (gt_bboxes_per_image[:, 1]).unsqueeze(1).repeat(
        1, total_num_anchors
    ) + center_radius * expanded_strides_per_image.unsqueeze(0)

    c_l = x_centers_per_image - gt_bboxes_per_image_l
    c_r = gt_bboxes_per_image_r - x_centers_per_image
    c_t = y_centers_per_image - gt_bboxes_per_image_t
    c_b = gt_bboxes_per_image_b - y_centers_per_image
    center_deltas = torch.stack([c_l, c_t, c_r, c_b], 2)
    is_in_centers = center_deltas.min(dim=-1).values > 0.0
    is_in_centers_all = is_in_centers.sum(dim=0) > 0

    # in boxes and in centers
    is_in_boxes_anchor = is_in_boxes_all | is_in_centers_all

    is_in_boxes_and_center = (
        is_in_boxes[:, is_in_boxes_anchor] & is_in_centers[:, is_in_boxes_anchor]
    )
    # is_in_boxes_anchor:用于预测框(正样本)的筛选
    # is_in_boxes_and_center:用于真实框和正样本的对齐
    return is_in_boxes_anchor, is_in_boxes_and_center

def dynamic_k_matching(cost, pair_wise_ious, gt_classes, num_gt, fg_mask):
    # Dynamic K
    # ---------------------------------------------------------------
    matching_matrix = torch.zeros_like(cost, dtype=torch.uint8)
    # ciou 值
    ious_in_boxes_matrix = pair_wise_ious
    # 确定k值,用于选取前k个ciou值,用于后面动态k的计算
    n_candidate_k = min(10, ious_in_boxes_matrix.size(1))
    # 筛选出每个真实框与之匹配的k个ciou值最高的正样本
    topk_ious, _ = torch.topk(ious_in_boxes_matrix, n_candidate_k, dim=1)
    # 根据筛选出的ciou值确定动态k,用于筛选匹配损失最低的正样本
    dynamic_ks = torch.clamp(topk_ious.sum(1).int(), min=1)
    dynamic_ks = dynamic_ks.tolist()
    # 为每个真实框筛选出动态k个正样本
    for gt_idx in range(num_gt):
        _, pos_idx = torch.topk(
            cost[gt_idx], k=dynamic_ks[gt_idx], largest=False
        )
        matching_matrix[gt_idx][pos_idx] = 1

    del topk_ious, dynamic_ks, pos_idx

    # 每个正样本与之匹配的真实框的个数
    anchor_matching_gt = matching_matrix.sum(0)
    # 对于一个正样本匹配多个真实框的情况进行处理
    if (anchor_matching_gt > 1).sum() > 0:
        # cost[:, anchor_matching_gt > 1] :筛选出一个正样本匹配多个真实框的匹配损失
        # cost_argmin : 为与正样本匹配损失最小的真实框的位置
        _, cost_argmin = torch.min(cost[:, anchor_matching_gt > 1], dim=0)
        matching_matrix[:, anchor_matching_gt > 1] *= 0
        matching_matrix[cost_argmin, anchor_matching_gt > 1] = 1
    # 正样本的 mask
    fg_mask_inboxes = matching_matrix.sum(0) > 0
    # 正样本的数量
    num_fg = fg_mask_inboxes.sum().item()

    fg_mask[fg_mask.clone()] = fg_mask_inboxes

    # 与正样本相匹配的真实框的 index
    matched_gt_inds = matching_matrix[:, fg_mask_inboxes].argmax(0)
    # 与真实框对应的类别
    gt_matched_classes = gt_classes[matched_gt_inds]
    # 经过筛选后真实框与正样本的ciou
    pred_ious_this_matching = (matching_matrix * pair_wise_ious).sum(0)[
        fg_mask_inboxes
    ]
    return num_fg, gt_matched_classes, pred_ious_this_matching, matched_gt_inds

import math
def bboxes_iou(pred, target, xyxy=True):
    w1 = pred[:,None, 2]
    h1 = pred[:, None,3]
    w2 = target[:, 2]
    h2 = target[:, 3]

    area1 = w1 * h1
    area2 = w2 * h2

    center_x1 = pred[:,None, 0]
    center_y1 = pred[:,None, 1]
    center_x2 = target[:, 0]
    center_y2 = target[:, 1]

    inter_min_xy = torch.max(
        (pred[:,None, :2] - pred[:,None, 2:] / 2), (target[:, :2] - target[:, 2:] / 2)
    )
    inter_max_xy = torch.min(
        (pred[:,None, :2] + pred[:,None, 2:] / 2), (target[:, :2] + target[:, 2:] / 2)
    )
    out_min_xy = torch.min(
        (pred[:,None, :2] - pred[:,None, 2:] / 2), (target[:, :2] - target[:, 2:] / 2)
    )
    out_max_xy = torch.max(
        (pred[:,None, :2] + pred[:, None,2:] / 2), (target[:, :2] + target[:, 2:] / 2)
    )

    inter = torch.clamp((inter_max_xy - inter_min_xy), min=0)
    inter_area = inter[:, 0] * inter[:, 1]
    inter_diag = (center_x2 - center_x1)**2 + (center_y2 - center_y1)**2
    outer = torch.clamp((out_max_xy - out_min_xy), min=0)
    outer_diag = (outer[:, 0] ** 2) + (outer[:, 1] ** 2)
    union = area1+area2-inter_area
    u = (inter_diag) / outer_diag
    iou = inter_area / union
    with torch.no_grad():
        arctan = torch.atan(w2 / h2) - torch.atan(w1 / h1)
        v = (4 / (math.pi ** 2)) * torch.pow((torch.atan(w2 / h2) - torch.atan(w1 / h1)), 2)
        S = 1 - iou
        alpha = v / (S + v)
        w_temp = 2 * w1
    ar = (8 / (math.pi ** 2)) * arctan * ((w1 - w_temp) * h1)
    cious = iou - (u + alpha * ar)
    cious = torch.clamp(cious,min=-1.0,max = 1.0)
    return cious

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值