mmdetection2 定制Mode篇

mmdetection Model篇

mmdetection主要将模型分为5大类,分别为backbone(ResNet, MobileNet)、neck(FPN, PAFPN)、head(bbox prediction and mask prediction)、roi extractor(RoI Align)以及loss(FocalLoss, L1Loss, and GHMLoss)部分。


一、Develop new components?

添加一个新的backbone需要三个步骤:

1. 定义一个新的backbone

mmdet/models/backbones/mobilenet.py 中创建一个新的文件

import torch.nn as nn

from ..builder import BACKBONES


@BACKBONES.register_module()
class MobileNet(nn.Module):

    def __init__(self, arg1, arg2):
        pass

    def forward(self, x):  # should return a tuple
        pass

    def init_weights(self, pretrained=None):
        pass

2. 导入这个module

mmdetection给了两种导入方式,一种是在

mmdet/models/backbones/init.py 中使用

from .mobilenet import MobileNet

这里我们可以看一下当前init.py文件的一个配置情况:

from .darknet import Darknet
from .detectors_resnet import DetectoRS_ResNet
from .detectors_resnext import DetectoRS_ResNeXt
from .hourglass import HourglassNet
from .hrnet import HRNet
from .regnet import RegNet
from .res2net import Res2Net
from .resnest import ResNeSt
from .resnet import ResNet, ResNetV1d
from .resnext import ResNeXt
from .ssd_vgg import SSDVGG
from .trident_resnet import TridentResNet

__all__ = [
    'RegNet', 'ResNet', 'ResNetV1d', 'ResNeXt', 'SSDVGG', 'HRNet', 'Res2Net',
    'HourglassNet', 'DetectoRS_ResNet', 'DetectoRS_ResNeXt', 'Darknet',
    'ResNeSt', 'TridentResNet'
]

在这里插入图片描述

或者尝试添加代码到config文件中避免更改原始的code

custom_imports = dict(
    imports=['mmdet.models.backbones.mobilenet'],
    allow_failed_imports=False)

3. 在自己的config文件中使用该backbone

model = dict(
    ...
    backbone=dict(
        type='MobileNet',
        arg1=xxx,
        arg2=xxx),
    ...

二、Add new necks

1.定义一个neck

修改文件位置在:

mmdet/models/necks/pafpn.py

from ..builder import NECKS

@NECKS.register_module()
class PAFPN(nn.Module):

    def __init__(self,
                in_channels,
                out_channels,
                num_outs,
                start_level=0,
                end_level=-1,
                add_extra_convs=False):
        pass

    def forward(self, inputs):
        # implementation is ignored
        pass

2.导入该module

和导入model的方式类似,也分为两种方式:
一种是在

mmdet/models/necks/init.py

中导入信息

from .pafpn import PAFPN

或者添加下方代码到config文件

custom_imports = dict(
    imports=['mmdet.models.necks.pafpn.py'],
    allow_failed_imports=False)

3.更改config文件

neck=dict(
    type='PAFPN',
    in_channels=[256, 512, 1024, 2048],
    out_channels=256,
    num_outs=5)
  • 1
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值