这是mmrotate,anchor-base的head损失的计算部分
def loss(self,
cls_scores, # 类别得分,形状为(N, num_anchors * num_classes, H, W)
bbox_preds, # 每个尺度级别的边界框预测N代表尺度,形状为(N, num_anchors * 5, H, W)
gt_bboxes, # 每张图片的GT,形状为(num_gts, 5),格式为[cx, cy, w, h, a]a是角度
gt_labels, # 每个GT对应的类别索引
img_metas, # 每张图片的元信息,例如图片大小,缩放因子等
gt_bboxes_ignore=None): # 指定在计算损失时可以忽略的边界盒,默认为None
"""计算头部的损失。"""
# 获取多级别特征图的大小
featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
# 确保特征图的数量与锚点生成器中的级别数量相同
assert len(featmap_sizes) == self.anchor_generator.num_levels
# 获取设备信息
device = cls_scores[0].device
# 获取所有尺度的anchor列表和anchor是否有效的判断bool列表
anchor_list, valid_flag_list = self.get_anchors(
featmap_sizes, img_metas, device=device)
# 根据是否使用sigmoid分类来确定标签通道的数量
label_channels = self.cls_out_channels if self.use_sigmoid_cls else 1
# 获取目标,包括标签和边界盒
cls_reg_targets = self.get_targets(
anchor_list,
valid_flag_list,
gt_bboxes,
img_metas,
gt_bboxes_ignore_list=gt_bboxes_ignore,
gt_labels_list=gt_labels,
label_channels=label_channels)
# 如果目标为空,则返回None
if cls_reg_targets is None:
return None
# 分类和回归目标
(labels_list, label_weights_list, bbox_targets_list, bbox_weights_list,
num_total_pos, num_total_neg) = cls_reg_targets
# 根据是否采样来确定总样本数
num_total_samples = (
num_total_pos + num_total_neg if self.sampling else num_total_pos)
# 计算多个级别的锚点数量
num_level_anchors = [anchors.size(0) for anchors in anchor_list[0]]
# 将所有级别的锚点和有效与否的bool列表拼接成一个张量
concat_anchor_list = []
for i, _ in enumerate(anchor_list):
concat_anchor_list.append(torch.cat(anchor_list[i]))
all_anchor_list = images_to_levels(concat_anchor_list,
num_level_anchors)
# 应用单个损失计算函数,计算分类和回归损失
losses_cls, losses_bbox = multi_apply(
self.loss_single,
cls_scores,
bbox_preds,
all_anchor_list,
labels_list,
label_weights_list,
bbox_targets_list,
bbox_weights_list,
num_total_samples=num_total_samples)
# 返回分类和回归损失的字典
return dict(loss_cls=losses_cls, loss_bbox=losses_bbox)