我已经在conda虚拟环境中安装好了mmdet库,自定义模型的模型如下:
import torch
import torch.nn as nn
from mmdet.models import DETECTORS
@DETECTORS.register_module()
class A(nn.Module):
def __init__(self, train_cfg, test_cfg):
super(A, self).__init__()
self.a = nn.Parameter(torch.randn(3, 4))
self.train_cfg = train_cfg
self.test_cfg = test_cfg
def forward(self):
return self.a
通过调用DETECTORS.register_module()方法实现模型的注册,
我的配置文件如下,
model = dict(
type='B',
train_cfg=dict(),
test_cfg=dict()
)
在主函数中通过配置文件实例化model,
from mmcv import Config
from mmdet3d.models import build_model
import os
if __name__ == '__main__':
cfg = Config.fromfile('cfg.py')
model = build_model(cfg.model)
此时mmcv库报错:
KeyError: 'A is not in the models registry'
修改后的代码如下:
from mmcv import Config
from mmdet3d.models import build_model
import os
from temp.model import A
if __name__ == '__main__':
cfg = Config.fromfile('cfg.py')
model = build_model(cfg.model)
也就是说在文件中导入class A时,装饰器函数才会生效,模块注册才完成。