YOLOX 无锚框（anchor-free）检测的流程与正样本匹配
1. 解耦头（decoupled head）
self.stems
self.reg_preds
self.obj_preds
解耦头把不同任务拆成独立分支分别预测（分类分支另有 self.cls_preds，这里未列出），各分支均由普通卷积构成。
2. 正样本匹配（SimOTA）
# A candidate positive point must lie inside a GT box, or inside the square of half-side center_radius * stride around that box's centre.
def get_in_boxes_info(self, gt_bboxes_per_image, expanded_strides, x_shifts, y_shifts, total_num_anchors, num_gt, center_radius=2.5):
    """Find the anchor points that are candidates for positive samples.

    A point (the centre of a feature-map cell, converted to pixels via its
    stride) is a candidate for a ground-truth (GT) box when it lies strictly
    inside that box, or strictly inside the square of half-side
    ``center_radius * stride`` centred on the box centre.

    Args:
        self: unused; kept so the signature matches the method it came from.
        gt_bboxes_per_image: ``[num_gt, 4]`` GT boxes as ``(cx, cy, w, h)``.
        expanded_strides: per-point strides; only row ``0`` is read and it
            holds one stride per point (``n_anchors_all`` values).
        x_shifts, y_shifts: per-point grid offsets in cells; only row ``0``
            is read.
        total_num_anchors: number of points. No longer needed because
            broadcasting replaces ``repeat``; kept for backward compatibility.
        num_gt: number of GT boxes. Unused for the same reason; kept for
            backward compatibility.
        center_radius: half-side of the centre region, measured in strides.

    Returns:
        is_in_boxes_anchor: ``[n_anchors_all]`` bool, True for points inside
            any GT box or any GT centre region.
        is_in_boxes_and_center: ``[num_gt, is_in_boxes_anchor.sum()]`` bool,
            restricted to the candidate points; True where a point is inside
            both that GT's box and that GT's centre region.
    """
    strides = expanded_strides[0]
    # Point centres in pixels, shaped [1, n_anchors_all]. They broadcast
    # against per-GT columns of shape [num_gt, 1], so no num_gt-fold copies
    # are materialised.
    x_centers = ((x_shifts[0] + 0.5) * strides).unsqueeze(0)
    y_centers = ((y_shifts[0] + 0.5) * strides).unsqueeze(0)

    # Per-GT columns, each [num_gt, 1].
    gt_cx = gt_bboxes_per_image[:, 0:1]
    gt_cy = gt_bboxes_per_image[:, 1:2]
    gt_w = gt_bboxes_per_image[:, 2:3]
    gt_h = gt_bboxes_per_image[:, 3:4]

    # 1. The point lies strictly inside the GT box, i.e. all four distances
    #    to the box edges are positive.
    gt_l = gt_cx - 0.5 * gt_w
    gt_r = gt_cx + 0.5 * gt_w
    gt_t = gt_cy - 0.5 * gt_h
    gt_b = gt_cy + 0.5 * gt_h
    bbox_deltas = torch.stack(
        [x_centers - gt_l, y_centers - gt_t, gt_r - x_centers, gt_b - y_centers], 2
    )  # [num_gt, n_anchors_all, 4]
    is_in_boxes = bbox_deltas.min(dim=-1).values > 0.0  # [num_gt, n_anchors_all]
    is_in_boxes_all = is_in_boxes.any(dim=0)  # [n_anchors_all]

    # 2. The point lies strictly inside the square of half-side
    #    center_radius * stride centred on the GT centre.
    radius = center_radius * strides.unsqueeze(0)  # [1, n_anchors_all]
    ctr_l = gt_cx - radius
    ctr_r = gt_cx + radius
    ctr_t = gt_cy - radius
    ctr_b = gt_cy + radius
    center_deltas = torch.stack(
        [x_centers - ctr_l, y_centers - ctr_t, ctr_r - x_centers, ctr_b - y_centers], 2
    )  # [num_gt, n_anchors_all, 4]
    is_in_centers = center_deltas.min(dim=-1).values > 0.0  # [num_gt, n_anchors_all]
    is_in_centers_all = is_in_centers.any(dim=0)  # [n_anchors_all]

    # Candidates satisfy either condition for at least one GT. The per-GT
    # "both conditions" matrix is restricted to those candidate columns.
    is_in_boxes_anchor = is_in_boxes_all | is_in_centers_all
    is_in_boxes_and_center = is_in_boxes[:, is_in_boxes_anchor] & is_in_centers[:, is_in_boxes_anchor]
    return is_in_boxes_anchor, is_in_boxes_and_center
通过计算代价（cost）矩阵来动态匹配正样本。
代价矩阵由分类损失与 IoU 损失（权重 3.0）相加组成；对于没有同时落在真实框内和中心区域内的特征点，再额外加上 100000 的惩罚项，使其几乎不会被选中。
# Excerpt of the SimOTA dynamic-k matching step (indentation was lost in this
# paste; the enclosing def is not shown). pair_wise_cls_loss,
# pair_wise_ious_loss, pair_wise_ious, gt_classes, num_gt, fg_mask and
# is_in_boxes_and_center are all defined outside this excerpt;
# is_in_boxes_and_center is presumably the second output of get_in_boxes_info.
#
# Cost matrix: for every GT box, its dynamic-k lowest-cost candidate points
# become its positive samples. Points not inside both the GT box and its
# centre region get a +100000.0 penalty, so they are picked only if there are
# not enough unpenalised points.
cost = pair_wise_cls_loss + 3.0 * pair_wise_ious_loss + 100000.0 * (~is_in_boxes_and_center).float()
#-------------------------------------------------------#
# Shapes (num_cand = number of True entries in fg_mask):
#   cost            [num_gt, num_cand]
#   pair_wise_ious  [num_gt, num_cand]
#   gt_classes      [num_gt]
#   fg_mask         [n_anchors_all]
#   matching_matrix [num_gt, num_cand]
#-------------------------------------------------------#
matching_matrix = torch.zeros_like(cost)
#------------------------------------------------------------#
# Dynamic k: for each GT take its n_candidate_k (at most 10) highest IoUs,
# sum them and truncate to int, clamped to at least 1. That is how many
# points the GT will be assigned.
#   topk_ious  [num_gt, n_candidate_k]
#   dynamic_ks [num_gt]
#------------------------------------------------------------#
n_candidate_k = min(10, pair_wise_ious.size(1))
topk_ious, _ = torch.topk(pair_wise_ious, n_candidate_k, dim=1)
dynamic_ks = torch.clamp(topk_ious.sum(1).int(), min=1)
for gt_idx in range(num_gt):
#------------------------------------------------------------#
# (loop body: the next two code lines)
# Mark this GT's dynamic_ks[gt_idx] lowest-cost points as its positives.
#------------------------------------------------------------#
_, pos_idx = torch.topk(cost[gt_idx], k=dynamic_ks[gt_idx].item(), largest=False)
matching_matrix[gt_idx][pos_idx] = 1.0
# Drop references to the temporaries.
# NOTE(review): if num_gt == 0 the loop above never runs, pos_idx is never
# bound, and this del raises NameError; callers presumably skip matching when
# an image has no GT boxes -- confirm.
del topk_ious, dynamic_ks, pos_idx
#------------------------------------------------------------#
# anchor_matching_gt [num_cand]: how many GTs each candidate point was
# assigned to.
#------------------------------------------------------------#
anchor_matching_gt = matching_matrix.sum(0)
if (anchor_matching_gt > 1).sum() > 0:
#------------------------------------------------------------#
# (if body: the next three code lines)
# A point assigned to several GTs is cleared and then re-assigned to the
# single GT with the lowest cost for that point.
#------------------------------------------------------------#
_, cost_argmin = torch.min(cost[:, anchor_matching_gt > 1], dim=0)
matching_matrix[:, anchor_matching_gt > 1] *= 0.0
matching_matrix[cost_argmin, anchor_matching_gt > 1] = 1.0
#------------------------------------------------------------#
# fg_mask_inboxes [num_cand]: candidate points that ended up matched to a GT.
# num_fg: number of positive points.
#------------------------------------------------------------#
fg_mask_inboxes = matching_matrix.sum(0) > 0.0
num_fg = fg_mask_inboxes.sum().item()
#------------------------------------------------------------#
# Update fg_mask in place: among the positions that were candidates, keep
# only those that were matched. The index is cloned, presumably so the
# boolean mask used for indexing is not the tensor being written to.
#------------------------------------------------------------#
fg_mask[fg_mask.clone()] = fg_mask_inboxes
#------------------------------------------------------------#
# Per positive point:
#   matched_gt_inds: index of its matched GT (each selected column of
#     matching_matrix holds exactly one 1, so argmax finds it);
#   gt_matched_classes: class of that GT;
#   pred_ious_this_matching: the pair_wise_ious entry for that GT.
#------------------------------------------------------------#
matched_gt_inds = matching_matrix[:, fg_mask_inboxes].argmax(0)
gt_matched_classes = gt_classes[matched_gt_inds]
pred_ious_this_matching = (matching_matrix * pair_wise_ious).sum(0)[fg_mask_inboxes]
return num_fg, gt_matched_classes, pred_ious_this_matching, matched_gt_inds
以上即动态 k 匹配（dynamic k matching）的完整过程。