mmrotate中的rotate_anchor_head.py—loss()部分

这是mmrotate,anchor-base的head损失的计算部分

def loss(self,
         cls_scores, # 类别得分,形状为(N, num_anchors * num_classes, H, W)
         bbox_preds, # 每个尺度级别的边界框预测N代表尺度,形状为(N, num_anchors * 5, H, W)
         gt_bboxes, # 每张图片的GT,形状为(num_gts, 5),格式为[cx, cy, w, h, a]a是角度
         gt_labels, # 每个GT对应的类别索引
         img_metas, # 每张图片的元信息,例如图片大小,缩放因子等
         gt_bboxes_ignore=None): # 指定在计算损失时可以忽略的边界盒,默认为None
    """计算头部的损失。"""
    
    # 获取多级别特征图的大小 
    featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
    # 确保特征图的数量与锚点生成器中的级别数量相同
    assert len(featmap_sizes) == self.anchor_generator.num_levels

    # 获取设备信息
    device = cls_scores[0].device

    # 获取所有尺度的anchor列表和anchor是否有效的判断bool列表
    anchor_list, valid_flag_list = self.get_anchors(
        featmap_sizes, img_metas, device=device)
    # 根据是否使用sigmoid分类来确定标签通道的数量
    label_channels = self.cls_out_channels if self.use_sigmoid_cls else 1
    # 获取目标,包括标签和边界盒
    cls_reg_targets = self.get_targets(
        anchor_list,
        valid_flag_list,
        gt_bboxes,
        img_metas,
        gt_bboxes_ignore_list=gt_bboxes_ignore,
        gt_labels_list=gt_labels,
        label_channels=label_channels)
    # 如果目标为空,则返回None
    if cls_reg_targets is None:
        return None
    # 分类和回归目标
    (labels_list, label_weights_list, bbox_targets_list, bbox_weights_list,
     num_total_pos, num_total_neg) = cls_reg_targets
    # 根据是否采样来确定总样本数
    num_total_samples = (
        num_total_pos + num_total_neg if self.sampling else num_total_pos)

    # 计算多个级别的锚点数量
    num_level_anchors = [anchors.size(0) for anchors in anchor_list[0]]
    # 将所有级别的锚点和有效与否的bool列表拼接成一个张量
    concat_anchor_list = []
    for i, _ in enumerate(anchor_list):
        concat_anchor_list.append(torch.cat(anchor_list[i]))
    all_anchor_list = images_to_levels(concat_anchor_list,
                                       num_level_anchors)

    # 应用单个损失计算函数,计算分类和回归损失
    losses_cls, losses_bbox = multi_apply(
        self.loss_single,
        cls_scores,
        bbox_preds,
        all_anchor_list,
        labels_list,
        label_weights_list,
        bbox_targets_list,
        bbox_weights_list,
        num_total_samples=num_total_samples)
    # 返回分类和回归损失的字典
    return dict(loss_cls=losses_cls, loss_bbox=losses_bbox)

  • 3
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值