mmdetection: Models
mmdetection divides model components into five main types: backbone (e.g. ResNet, MobileNet), neck (e.g. FPN, PAFPN), head (bbox prediction and mask prediction), RoI extractor (e.g. RoI Align), and loss (e.g. FocalLoss, L1Loss, and GHMLoss).
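To see where each of the five component types sits, here is an abridged Faster R-CNN-style model config written as a plain Python dict. It is loosely modeled on mmdetection 2.x's faster_rcnn_r50_fpn config, but most arguments are omitted, so treat the values as illustrative. `collect_types` is a hypothetical helper that lists every `type` string; each one names a class that has to be registered somewhere.

```python
# Abridged, illustrative Faster R-CNN-style config (most arguments omitted).
model = dict(
    type='FasterRCNN',
    backbone=dict(type='ResNet', depth=50),                        # backbone
    neck=dict(type='FPN', in_channels=[256, 512, 1024, 2048],
              out_channels=256, num_outs=5),                       # neck
    rpn_head=dict(
        type='RPNHead',                                            # head
        loss_cls=dict(type='CrossEntropyLoss', use_sigmoid=True),  # loss
        loss_bbox=dict(type='L1Loss')),                            # loss
    roi_head=dict(
        type='StandardRoIHead',
        bbox_roi_extractor=dict(
            type='SingleRoIExtractor',                             # RoI extractor
            roi_layer=dict(type='RoIAlign', output_size=7)),
        bbox_head=dict(
            type='Shared2FCBBoxHead',                              # head
            loss_cls=dict(type='CrossEntropyLoss'),                # loss
            loss_bbox=dict(type='L1Loss'))))                       # loss


def collect_types(cfg):
    """Hypothetical helper: return every 'type' string in a nested config."""
    found = []
    if isinstance(cfg, dict):
        if 'type' in cfg:
            found.append(cfg['type'])
        # Recurse into nested dicts; non-dict values contribute nothing.
        for value in cfg.values():
            found.extend(collect_types(value))
    return found


print(collect_types(model))
```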
I. Develop new components: add a new backbone
Adding a new backbone takes three steps:
1. Define a new backbone
Create a new file, mmdet/models/backbones/mobilenet.py:
import torch.nn as nn

from ..builder import BACKBONES


@BACKBONES.register_module()
class MobileNet(nn.Module):

    def __init__(self, arg1, arg2):
        super().__init__()  # nn.Module must be initialized before adding layers
        pass

    def forward(self, x):  # should return a tuple of feature maps
        pass

    def init_weights(self, pretrained=None):
        pass
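`@BACKBONES.register_module()` acts as a class decorator: when the `class` statement runs, the class is stored in a name-to-class table. That happens at import time, which is why the module must be imported (step 2) before a config can refer to `'MobileNet'`. The `Registry` below is a simplified, hypothetical stand-in written for illustration; mmcv's real `Registry` has more options.

```python
class Registry:
    """Simplified, hypothetical stand-in for mmcv's Registry."""

    def __init__(self, name):
        self.name = name
        self._module_dict = {}

    def register_module(self):
        # Returns a decorator; the decorator body runs when the decorated
        # class is defined, i.e. when its module is imported.
        def _register(cls):
            key = cls.__name__
            if key in self._module_dict:
                raise KeyError(f'{key} is already registered in {self.name}')
            self._module_dict[key] = cls
            return cls  # the class itself is returned unchanged
        return _register

    def get(self, key):
        # Returns None for names that were never registered.
        return self._module_dict.get(key)


BACKBONES = Registry('backbone')


@BACKBONES.register_module()
class MobileNet:
    def __init__(self, arg1, arg2):
        self.arg1, self.arg2 = arg1, arg2


print(BACKBONES.get('MobileNet') is MobileNet)  # True
```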
2. Import the module
mmdetection offers two ways to import it. The first is to add the following line to mmdet/models/backbones/__init__.py:
from .mobilenet import MobileNet
For reference, this is what mmdet/models/backbones/__init__.py currently contains:
from .darknet import Darknet
from .detectors_resnet import DetectoRS_ResNet
from .detectors_resnext import DetectoRS_ResNeXt
from .hourglass import HourglassNet
from .hrnet import HRNet
from .regnet import RegNet
from .res2net import Res2Net
from .resnest import ResNeSt
from .resnet import ResNet, ResNetV1d
from .resnext import ResNeXt
from .ssd_vgg import SSDVGG
from .trident_resnet import TridentResNet
__all__ = [
    'RegNet', 'ResNet', 'ResNetV1d', 'ResNeXt', 'SSDVGG', 'HRNet', 'Res2Net',
    'HourglassNet', 'DetectoRS_ResNet', 'DetectoRS_ResNeXt', 'Darknet',
    'ResNeSt', 'TridentResNet'
]
Alternatively, to avoid modifying the original code, add the following to your config file:
custom_imports = dict(
    imports=['mmdet.models.backbones.mobilenet'],
    allow_failed_imports=False)
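`custom_imports` asks the config loader to import each listed module by its dotted name. That import runs the `@BACKBONES.register_module()` decorator without any change to `__init__.py`. The helper below is a hypothetical approximation of that behavior built on `importlib.import_module`; it is not mmcv's code. It also shows why the list must hold dotted module paths such as `'mmdet.models.backbones.mobilenet'`, not file names ending in `.py`.

```python
import importlib
import warnings


def import_custom_modules(imports, allow_failed_imports=False):
    """Hypothetical approximation of how custom_imports is handled."""
    if isinstance(imports, str):
        imports = [imports]
    imported = []
    for name in imports:
        try:
            module = importlib.import_module(name)  # expects a dotted path
        except ImportError:
            if not allow_failed_imports:
                raise
            warnings.warn(f'{name} failed to import and is ignored.')
            module = None
        imported.append(module)
    return imported


# Standard-library modules stand in for mmdet ones so this runs anywhere.
print(import_custom_modules(['json', 'os.path']))
# A '.py' suffix is read as a submodule named 'py', so the import fails.
print(import_custom_modules(['json.py'], allow_failed_imports=True))  # [None]
```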
3. Use the backbone in your config file
model = dict(
    ...
    backbone=dict(
        type='MobileNet',
        arg1=xxx,
        arg2=xxx),
    ...)
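When the detector is built, the `backbone` dict is passed to the registry: the `type` key selects the registered class and the remaining keys become constructor arguments. The sketch below is a hypothetical, minimal version of that lookup (mmcv's own `build_from_cfg` does more). A plain dict plays the registry, and the values `1` and `2` replace the `xxx` placeholders purely for illustration.

```python
def build_from_cfg(cfg, registry):
    """Hypothetical minimal builder: look up cfg['type'], pass the rest as kwargs."""
    args = dict(cfg)  # copy so the caller's config is not mutated
    obj_type = args.pop('type')
    if obj_type not in registry:
        raise KeyError(f'{obj_type} is not registered')
    return registry[obj_type](**args)


class MobileNet:
    # Illustrative stand-in for the backbone defined in step 1.
    def __init__(self, arg1, arg2):
        self.arg1 = arg1
        self.arg2 = arg2


BACKBONES = {'MobileNet': MobileNet}  # plain dict used as a toy registry

backbone_cfg = dict(type='MobileNet', arg1=1, arg2=2)
backbone = build_from_cfg(backbone_cfg, BACKBONES)
print(type(backbone).__name__, backbone.arg1, backbone.arg2)  # MobileNet 1 2
print(backbone_cfg['type'])  # MobileNet: the original config is unchanged
```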
II. Add a new neck
1. Define a neck
Define it in mmdet/models/necks/pafpn.py:
import torch.nn as nn

from ..builder import NECKS


@NECKS.register_module()
class PAFPN(nn.Module):

    def __init__(self,
                 in_channels,
                 out_channels,
                 num_outs,
                 start_level=0,
                 end_level=-1,
                 add_extra_convs=False):
        super().__init__()  # nn.Module must be initialized before adding layers
        pass

    def forward(self, inputs):
        # implementation omitted
        pass
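2. Import the module
As with the backbone, there are two ways to import it. The first is to add the following line to mmdet/models/necks/__init__.py:
from .pafpn import PAFPN
Alternatively, add the following to your config file: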
custom_imports = dict(
    imports=['mmdet.models.necks.pafpn'],
    allow_failed_imports=False)
3. Modify the config file
neck=dict(
    type='PAFPN',
    in_channels=[256, 512, 1024, 2048],
    out_channels=256,
    num_outs=5)
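`in_channels`, `start_level` and `num_outs` together decide how many pyramid levels come from the backbone and how many extra levels the neck has to add. The hypothetical helper below is a worked example of that bookkeeping for the default `end_level=-1` only, assuming it matches the logic in mmdetection's `FPN`, which `PAFPN` builds on. With the config above, all four backbone levels are used and one extra level is added to reach `num_outs=5`.

```python
def pyramid_levels(in_channels, num_outs, start_level=0):
    """Hypothetical helper: FPN-style level bookkeeping for end_level=-1 only.

    Returns the indices of the backbone levels fed into the neck and the
    number of extra output levels the neck has to create on top of them.
    """
    num_ins = len(in_channels)
    used = list(range(start_level, num_ins))  # end_level=-1: up to the last level
    if num_outs < len(used):
        raise ValueError('num_outs must cover every used backbone level')
    num_extra = num_outs - len(used)
    return used, num_extra


# Values from the neck config above.
print(pyramid_levels([256, 512, 1024, 2048], num_outs=5))  # ([0, 1, 2, 3], 1)
# Skipping the first backbone level leaves two extra levels to create.
print(pyramid_levels([256, 512, 1024, 2048], num_outs=5, start_level=1))  # ([1, 2, 3], 2)
```

In mmdetection's `FPN`, `add_extra_convs=False` means those extra levels are produced by stride-2 max pooling rather than by additional convolutions.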