SimOTA简介
1.确定正样本候选区域(使用中心先验)
2.计算每个样本对每个真实框的Reg + Cls loss(Loss aware)
3.使用每个真实框的预测样本确定它需要分配到的正样本数(Dynamic k)
获取与当前真实框的ciou前10的样本;
将这Top10样本的ciou求和取整,即为当前真实框的dynamic k,dynamic k最小保证为1。Top10中的10这个数字并不敏感,在5-15之间几乎都没有影响
4.为每个真实框取loss最小的前dynamic k个样本作为正样本
5.去掉同一个样本被分配到多个真实框的正样本的情况(全局信息)
代码解释
import torch
import torch.nn.functional as F
'''
# 输出把预测框,预测框的得分,类别进行拼接
output = torch.cat([reg_output, obj_output, cls_output], 1)
# 对输出进行解码 size 为 [1, w*h, 5+类别数],同时生成与之对应的网格 size 为 [1, w*h, 2]
output, grid = get_output_and_grid(
    output, k, stride_this_level, xin[0].type()
)
x_shifts.append(grid[:, :, 0])
y_shifts.append(grid[:, :, 1])
# 每一个输出结果的下采样步长
expanded_strides.append(
    torch.zeros(1, grid.shape[1])
    .fill_(stride_this_level)
    .type_as(xin[0])
)
'''
def get_output_and_grid(output, k, stride, dtype):
    """Decode one feature level's raw head output and build its cell grid.

    Args:
        output: Raw head output laid out as [batch, channels, height, width];
            channels 0-1 are the xy offsets and channels 2-3 the log-scale
            width/height.
        k: Index of the feature level. Unused here (no grid cache is kept);
            retained so existing call sites keep working.
        stride: Down-sampling stride of this feature level.
        dtype: Tensor type passed to ``Tensor.type`` for the grid, e.g. the
            string returned by ``some_tensor.type()``.

    Returns:
        A tuple ``(decoded, grid)``: ``decoded`` is
        [batch, height * width, channels] with xy and wh scaled by ``stride``;
        ``grid`` is [1, height * width, 2] holding the (x, y) index of every
        cell in row-major order.

    Note:
        The decode writes in place on the reshaped tensor. ``reshape`` may
        return a view, so the caller's ``output`` storage can be overwritten;
        pass a tensor you do not need afterwards.
    """
    batch_size = output.shape[0]
    hsize, wsize = output.shape[-2:]

    # meshgrid returns (row, col) index maps; stack them as (x, y) so that
    # grid[..., 0] is the column index and grid[..., 1] the row index.
    yv, xv = torch.meshgrid([torch.arange(hsize), torch.arange(wsize)])
    grid = torch.stack((xv, yv), 2).view(1, -1, 2).type(dtype)

    # Move channels last BEFORE flattening the spatial dims, so every row of
    # the result holds the channels of exactly one cell, in the same
    # row-major order as the grid.
    output = output.permute(0, 2, 3, 1).reshape(batch_size, hsize * wsize, -1)

    # Decode: centre = (offset + cell index) * stride, size = exp(raw) * stride.
    output[..., :2] = (output[..., :2] + grid) * stride
    output[..., 2:4] = torch.exp(output[..., 2:4]) * stride
    return output, grid
@torch.no_grad()
def get_assignments(
    self,
    batch_idx,
    num_gt,
    total_num_anchors,
    gt_bboxes_per_image,
    gt_classes,
    bboxes_preds_per_image,
    expanded_strides,
    x_shifts,
    y_shifts,
    cls_preds,
    obj_preds,
):
    """Run SimOTA label assignment for a single image.

    Args:
        batch_idx: Index of the image inside the batch; used to select this
            image's rows from ``cls_preds`` and ``obj_preds``.
        num_gt: Number of ground-truth boxes in the image.
        total_num_anchors: Number of predictions (anchor points) per image.
        gt_bboxes_per_image: Ground-truth boxes of the image, one row per box
            as (cx, cy, w, h).
        gt_classes: Class index of every ground-truth box.
        bboxes_preds_per_image: Predicted boxes of the image.
        expanded_strides: Down-sampling stride of every prediction.
        x_shifts: Grid x index of every prediction.
        y_shifts: Grid y index of every prediction.
        cls_preds: Raw class logits for the whole batch.
        obj_preds: Raw objectness logits for the whole batch.

    Returns:
        A tuple ``(gt_matched_classes, fg_mask, pred_ious_this_matching,
        matched_gt_inds, num_fg)``. ``fg_mask`` is the per-anchor positive
        mask; it starts as the candidate mask and is narrowed in place by
        ``dynamic_k_matching``.
    """
    # Step 1: centre prior, keep only anchors inside a box or centre region.
    fg_mask, is_in_boxes_and_center = get_in_boxes_info(
        gt_bboxes_per_image,
        expanded_strides,
        x_shifts,
        y_shifts,
        total_num_anchors,
        num_gt,
    )

    bboxes_preds_per_image = bboxes_preds_per_image[fg_mask]
    cls_preds_ = cls_preds[batch_idx][fg_mask]
    obj_preds_ = obj_preds[batch_idx][fg_mask]
    num_in_boxes_anchor = bboxes_preds_per_image.shape[0]

    # Step 2: pairwise overlap, [num_gt, num_in_boxes_anchor].
    pair_wise_ious = bboxes_iou(gt_bboxes_per_image, bboxes_preds_per_image, False)

    # One-hot class targets repeated for every candidate:
    # [num_gt, num_in_boxes_anchor, num_classes].
    gt_cls_per_image = (
        F.one_hot(gt_classes.to(torch.int64), self.num_classes)
        .float()
        .unsqueeze(1)
        .repeat(1, num_in_boxes_anchor, 1)
    )

    # The overlap returned by bboxes_iou can be negative (it is clamped to
    # [-1, 1]); log of a negative number is NaN and would poison the cost
    # matrix, so floor it at 0 before taking the log.
    pair_wise_ious_loss = -torch.log(pair_wise_ious.clamp(min=0) + 1e-8)

    # Classification cost is computed in full precision even under AMP.
    with torch.cuda.amp.autocast(enabled=False):
        cls_preds_ = (
            cls_preds_.float().unsqueeze(0).repeat(num_gt, 1, 1).sigmoid_()
            * obj_preds_.float().unsqueeze(0).repeat(num_gt, 1, 1).sigmoid_()
        )
        pair_wise_cls_loss = F.binary_cross_entropy(
            cls_preds_.sqrt_(), gt_cls_per_image, reduction="none"
        ).sum(-1)
    del cls_preds_

    # Total matching cost. The large constant effectively excludes pairs whose
    # anchor is not inside both the box and the centre region of that
    # ground truth.
    cost = (
        pair_wise_cls_loss
        + 3.0 * pair_wise_ious_loss
        + 100000.0 * (~is_in_boxes_and_center)
    )

    # Steps 3-5: dynamic-k selection and conflict resolution.
    (
        num_fg,
        gt_matched_classes,
        pred_ious_this_matching,
        matched_gt_inds,
    ) = dynamic_k_matching(cost, pair_wise_ious, gt_classes, num_gt, fg_mask)
    del pair_wise_cls_loss, cost, pair_wise_ious, pair_wise_ious_loss

    return (
        gt_matched_classes,
        fg_mask,
        pred_ious_this_matching,
        matched_gt_inds,
        num_fg,
    )
def get_in_boxes_info(
    gt_bboxes_per_image,
    expanded_strides,
    x_shifts,
    y_shifts,
    total_num_anchors,
    num_gt,
    center_radius=2.5,
):
    """Pick candidate anchors with the box prior and the centre prior.

    An anchor is a candidate when its centre lies strictly inside at least one
    ground-truth box, or strictly inside at least one ground truth's centre
    region (a square of half-size ``center_radius * stride`` around the
    ground-truth centre).

    Args:
        gt_bboxes_per_image: Ground-truth boxes, one row per box as
            (cx, cy, w, h).
        expanded_strides: Per-anchor stride; only ``expanded_strides[0]`` is
            read.
        x_shifts: Per-anchor grid x index; only ``x_shifts[0]`` is read.
        y_shifts: Per-anchor grid y index; only ``y_shifts[0]`` is read.
        total_num_anchors: Number of anchors.
        num_gt: Number of ground-truth boxes.
        center_radius: Half-size of the centre region in units of the anchor's
            stride. Defaults to 2.5.

    Returns:
        A tuple ``(is_in_boxes_anchor, is_in_boxes_and_center)``:
        ``is_in_boxes_anchor`` is a bool mask over all anchors marking the
        candidates; ``is_in_boxes_and_center`` is a bool matrix
        [num_gt, num_candidates] that is True where a candidate lies inside
        both the box and the centre region of that ground truth.
    """
    strides = expanded_strides[0]

    # Grid index * stride is the cell's top-left corner in input pixels; add
    # half a stride to get the cell centre, then tile to [num_gt, num_anchors].
    x_centers = (
        (x_shifts[0] * strides + 0.5 * strides).unsqueeze(0).expand(num_gt, -1)
    )
    y_centers = (
        (y_shifts[0] * strides + 0.5 * strides).unsqueeze(0).expand(num_gt, -1)
    )

    # Ground-truth centre and size, tiled to [num_gt, num_anchors].
    gt_cx = gt_bboxes_per_image[:, 0].unsqueeze(1).expand(-1, total_num_anchors)
    gt_cy = gt_bboxes_per_image[:, 1].unsqueeze(1).expand(-1, total_num_anchors)
    gt_w = gt_bboxes_per_image[:, 2].unsqueeze(1).expand(-1, total_num_anchors)
    gt_h = gt_bboxes_per_image[:, 3].unsqueeze(1).expand(-1, total_num_anchors)

    # Box prior: anchor centre inside the ground-truth box itself.
    is_in_boxes = _centers_inside(
        x_centers,
        y_centers,
        gt_cx - 0.5 * gt_w,
        gt_cy - 0.5 * gt_h,
        gt_cx + 0.5 * gt_w,
        gt_cy + 0.5 * gt_h,
    )

    # Centre prior: anchor centre inside a square around the ground-truth
    # centre whose half-size scales with the anchor's own stride.
    radius = center_radius * strides.unsqueeze(0)
    is_in_centers = _centers_inside(
        x_centers,
        y_centers,
        gt_cx - radius,
        gt_cy - radius,
        gt_cx + radius,
        gt_cy + radius,
    )

    # Per anchor: is it inside ANY box or ANY centre region?
    is_in_boxes_anchor = is_in_boxes.any(dim=0) | is_in_centers.any(dim=0)

    # Per (ground truth, candidate) pair: inside BOTH box and centre region.
    is_in_boxes_and_center = (
        is_in_boxes[:, is_in_boxes_anchor] & is_in_centers[:, is_in_boxes_anchor]
    )
    return is_in_boxes_anchor, is_in_boxes_and_center


def _centers_inside(x_centers, y_centers, left, top, right, bottom):
    """Return a bool mask of points lying strictly inside the given boxes."""
    deltas = torch.stack(
        [x_centers - left, y_centers - top, right - x_centers, bottom - y_centers],
        2,
    )
    return deltas.min(dim=-1).values > 0.0
def dynamic_k_matching(cost, pair_wise_ious, gt_classes, num_gt, fg_mask):
    """Assign candidates to ground truths with a per-ground-truth dynamic k.

    For every ground truth, k is the integer part of the sum of its (up to) 10
    largest overlaps, with a minimum of 1; the k lowest-cost candidates become
    its positives. A candidate claimed by several ground truths is kept only
    for the one with the lowest cost.

    Args:
        cost: Matching cost, [num_gt, num_candidates].
        pair_wise_ious: Overlap between ground truths and candidates, same
            shape as ``cost``.
        gt_classes: Class of every ground truth; indexed with the matched
            ground-truth indices.
        num_gt: Number of ground truths (rows of ``cost``) to process.
        fg_mask: Bool mask over all anchors marking the candidates. Updated
            IN PLACE so that only the finally selected positives stay True.

    Returns:
        A tuple ``(num_fg, gt_matched_classes, pred_ious_this_matching,
        matched_gt_inds)``: the number of positives, and for every positive
        the class of its ground truth, its overlap with that ground truth,
        and the ground truth's index.
    """
    matching_matrix = torch.zeros_like(cost, dtype=torch.uint8)

    # Dynamic k: sum of the top-n overlaps of each ground truth, at least 1.
    num_candidates = pair_wise_ious.size(1)
    n_candidate_k = min(10, num_candidates)
    topk_ious, _ = torch.topk(pair_wise_ious, n_candidate_k, dim=1)
    dynamic_ks = torch.clamp(topk_ious.sum(1).int(), min=1).tolist()

    for gt_idx in range(num_gt):
        # The minimum of 1 can exceed the number of candidates when there are
        # none at all; cap k so topk never asks for more than is available.
        k = min(dynamic_ks[gt_idx], num_candidates)
        if k == 0:
            continue
        _, pos_idx = torch.topk(cost[gt_idx], k=k, largest=False)
        matching_matrix[gt_idx][pos_idx] = 1

    # Resolve candidates matched to more than one ground truth: keep only the
    # ground truth with the smallest cost for that candidate.
    multi_matched = matching_matrix.sum(0) > 1
    if multi_matched.any():
        _, cost_argmin = torch.min(cost[:, multi_matched], dim=0)
        matching_matrix[:, multi_matched] = 0
        matching_matrix[cost_argmin, multi_matched] = 1

    # Candidates that ended up with a ground truth are the positives.
    fg_mask_inboxes = matching_matrix.sum(0) > 0
    num_fg = fg_mask_inboxes.sum().item()

    # Narrow the caller's all-anchor mask from "candidate" to "positive".
    fg_mask[fg_mask.clone()] = fg_mask_inboxes

    matched_gt_inds = matching_matrix[:, fg_mask_inboxes].argmax(0)
    gt_matched_classes = gt_classes[matched_gt_inds]
    pred_ious_this_matching = (matching_matrix * pair_wise_ious).sum(0)[
        fg_mask_inboxes
    ]
    return num_fg, gt_matched_classes, pred_ious_this_matching, matched_gt_inds
import math
def bboxes_iou(pred, target, xyxy=True):
    """Compute the pairwise Complete-IoU (CIoU) between two sets of boxes.

    CIoU = IoU - (d^2 / c^2 + alpha * v), where d is the distance between the
    box centres, c the diagonal of the smallest enclosing box, v the
    aspect-ratio mismatch and alpha = v / (1 - IoU + v).

    Args:
        pred: Boxes as a [N, 4] tensor.
        target: Boxes as a [M, 4] tensor.
        xyxy: If True the boxes are (x1, y1, x2, y2) corners; if False they
            are (cx, cy, w, h).

    Returns:
        A [N, M] tensor with the CIoU of every (pred, target) pair, clamped
        to [-1, 1].
    """
    eps = 1e-7  # guards the divisions against degenerate (zero-size) boxes

    if xyxy:
        pred_xy = (pred[:, :2] + pred[:, 2:4]) / 2
        pred_wh = pred[:, 2:4] - pred[:, :2]
        target_xy = (target[:, :2] + target[:, 2:4]) / 2
        target_wh = target[:, 2:4] - target[:, :2]
    else:
        pred_xy, pred_wh = pred[:, :2], pred[:, 2:4]
        target_xy, target_wh = target[:, :2], target[:, 2:4]

    # Broadcast to [N, 1, 2] against [1, M, 2] so every op below is pairwise.
    p_xy, p_wh = pred_xy[:, None, :], pred_wh[:, None, :]
    t_xy, t_wh = target_xy[None, :, :], target_wh[None, :, :]
    p_min, p_max = p_xy - p_wh / 2, p_xy + p_wh / 2
    t_min, t_max = t_xy - t_wh / 2, t_xy + t_wh / 2

    # Plain IoU. Index the LAST axis for the x/y components: the first two
    # axes are the pred and target indices.
    inter_wh = torch.clamp(torch.min(p_max, t_max) - torch.max(p_min, t_min), min=0)
    inter_area = inter_wh[..., 0] * inter_wh[..., 1]
    area_p = p_wh[..., 0] * p_wh[..., 1]
    area_t = t_wh[..., 0] * t_wh[..., 1]
    union = area_p + area_t - inter_area
    iou = inter_area / union.clamp(min=eps)

    # Distance penalty: squared centre distance over squared enclosing diagonal.
    center_dist_sq = ((t_xy - p_xy) ** 2).sum(-1)
    outer_wh = torch.clamp(torch.max(p_max, t_max) - torch.min(p_min, t_min), min=0)
    outer_diag_sq = (outer_wh ** 2).sum(-1)
    u = center_dist_sq / outer_diag_sq.clamp(min=eps)

    # Aspect-ratio penalty.
    arctan = torch.atan(t_wh[..., 0] / t_wh[..., 1].clamp(min=eps)) - torch.atan(
        p_wh[..., 0] / p_wh[..., 1].clamp(min=eps)
    )
    v = (4 / (math.pi ** 2)) * arctan ** 2
    with torch.no_grad():
        # alpha is a trade-off weight and is treated as a constant.
        alpha = v / (1 - iou + v).clamp(min=eps)

    cious = iou - (u + alpha * v)
    return torch.clamp(cious, min=-1.0, max=1.0)