MMDetection系列之(自定义损失函数)

MMDetection为用户提供不同的丢失功能。但是默认配置可能不适用于不同的数据集或模型,因此用户可能希望修改特定的损失以适应新的情况。

计算管道的损失

给定输入预测和目标,以及权值,损失函数将输入张量映射到最终损失标量。映射可以分为四个步骤:

  1. 设置采样方式为正采样和负采样。
  2. 通过损失核函数获得元素或样本的损失。
    3.明智地用一个权张量元素对损失进行加权。
  3. 将损失张量降为标量。
  4. 用标量对损失进行加权。

1、Set sampling method (step 1)

对于某些损失函数,需要采取抽样策略来避免正样本和负样本之间的不平衡。
例如,当在RPN头部使用CrossEntropyLoss时,我们需要在train_cfg中设置RandomSampler

train_cfg=dict(
    rpn=dict(
        sampler=dict(
            type='RandomSampler',
            num=256,
            pos_fraction=0.5,
            neg_pos_ub=-1,
            add_gt_as_proposals=False))

对于具有正负样品平衡机制的其他损失,如FocalLoss、GHMC和QualityFocalLoss,则不再需要采样器。

Tweaking loss

调整损失与步骤2、步骤4、步骤5更相关,大多数修改都可以在配置中指定。这里我们以Focal Loss (FL)为例。下面的代码狙击手分别是FL的构造方法和配置,它们实际上是一一对应的。

@LOSSES.register_module()
class FocalLoss(nn.Module):

    def __init__(self,
                 use_sigmoid=True,
                 gamma=2.0,
                 alpha=0.25,
                 reduction='mean',
                 loss_weight=1.0):
loss_cls=dict(
    type='FocalLoss',
    use_sigmoid=True,
    gamma=2.0,
    alpha=0.25,
    loss_weight=1.0)

2、Tweaking hyper-parameters (step 2)

gamma和beta是焦损的两个超参数。假设我们想要将gamma值改为1.5,alpha值改为0.5,那么我们可以在配置中如下所示指定它们:

loss_cls=dict(
    type='FocalLoss',
    use_sigmoid=True,
    gamma=1.5,
    alpha=0.5,
    loss_weight=1.0)

Tweaking the way of reduction (step 3)

对于FL,约简的默认方式是mean。例如,如果我们想要将约简从mean更改为sum,我们可以在配置中指定如下:

loss_cls=dict(
    type='FocalLoss',
    use_sigmoid=True,
    gamma=2.0,
    alpha=0.25,
    loss_weight=1.0,
    reduction='sum')

Tweaking loss weight (step 5)

这里的损失权值是一个标量,它控制多任务学习中不同损失的权值,例如分类损失和回归损失。例如,如果我们想将分类损失的损失权重改为0.5,我们可以在配置中如下所示:

loss_cls=dict(
    type='FocalLoss',
    use_sigmoid=True,
    gamma=2.0,
    alpha=0.25,
    loss_weight=0.5)

Weighting loss (step 3)

加权损失意味着我们明智地重新加权损失元素。更具体地说,我们将损失张量与一个具有相同形状的权张量相乘。因此,损失的不同分项可以按不同的比例计算,这就是所谓的“要素”。损失权值在不同的模型中是不同的,并且与上下文高度相关,但总体上有两种损失权值,用于分类损失的label_weights和用于框回归损失的bbox_weights。您可以在相应头部的get_target方法中找到它们。这里我们以atshead为例,它继承了AnchorHead但是覆盖了它的get_targets方法,从而产生了不同的label_weights和bbox_weights。

lass ATSSHead(AnchorHead):

    ...

    def get_targets(self,
                    anchor_list,
                    valid_flag_list,
                    gt_bboxes_list,
                    img_metas,
                    gt_bboxes_ignore_list=None,
                    gt_labels_list=None,
                    label_channels=1,
                    unmap_outputs=True):
  • 1
    点赞
  • 8
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
根据官方文档资料的自定义模型部分,MMDetection允许用户通过组合不同的模块组件来构建自定义的检测模型。MMDetection提供了丰富的即插即用的算法和模型,支持众多主流和最新的检测算法,比如Faster R-CNN等。然而,网上对于MMDetection的资料还相对较少,但有一篇博客提供了关于如何使用MMDetection替换自己实现的backbone结构的经验,该博客记录了作者如何设计的backbone结构来替换DETR模型的ResNet,并使用口罩检测数据集进行训练。 因此,MMDetection的自定义模型可以通过组合不同的模块组件来实现,并且有丰富的算法和模型可供选择,同时我们可以根据需要替换掉已有的backbone结构来进行训练。<span class="em">1</span><span class="em">2</span><span class="em">3</span> #### 引用[.reference_title] - *1* *3* [MMDetection系列 | 3. MMDetection定义模型训练](https://blog.csdn.net/weixin_44751294/article/details/126804581)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v93^chatsearchT3_2"}}] [.reference_item style="max-width: 50%"] - *2* [MMDetection实战:MMDetection训练与测试](https://download.csdn.net/download/hhhhhhhhhhwwwwwwwwww/85331635)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v93^chatsearchT3_2"}}] [.reference_item style="max-width: 50%"] [ .reference_list ]
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值