在mmdetection中大量使用了如下的decorator(其本质是一种工厂模式),本文详细介绍这种模式
from ..builder import HEADS
# Register the class in the HEADS registry via the decorator.
# ``...`` stands for the class body, which this snippet leaves out.
@HEADS.register_module()
class RepPointLocHead(AnchorFreeHead):
    ...
Registry
Registry就是上一篇 3.2中描述的工厂模式的Manager,有两个作用:
- 注册(register):记录对象名字到对象class的映射
- 创建(build) :根据配置信息中给出的名字和参数创建对应class的对象
这两个功能分别由register_module和build两个函数来实现。
1. example of Registry
对它的使用例子如下:
# Create a registry named 'models'.
MODELS = Registry('models')


# Register ResNet in MODELS through the decorator.
@MODELS.register_module()
class ResNet:
    ...


# Build an object from a config dict whose 'type' names the class.
resnet = MODELS.build(dict(type='ResNet'))
2. register_module
上面例子代码的类比代码为
class ResNet:
    ...


# The decorator applied by hand: call register_module() to get the
# decorator, then call that decorator on the class and rebind the name.
ResNet = MODELS.register_module()(ResNet)
resnet = MODELS.build(dict(type='ResNet'))
从源代码可以看出register_module返回一个Decorator,它的输入是一个class,把class.__name__到class的映射记录在self._module_dict中,然后原封不动输出这个class
class Registry:
    """Registry excerpt showing how classes are registered by name.

    ``...`` marks code that is left out of this excerpt.
    """

    def __init__(self, name, build_func=None, parent=None, scope=None):
        self._name = name
        ...

    def _register_module(self, module_class, module_name=None, force=False):
        """Store ``module_class`` in ``self._module_dict``.

        Args:
            module_class: The class to record.
            module_name: A single name, an iterable of names, or None to
                use ``module_class.__name__``.
            force: When true, an already registered name is overwritten
                instead of raising.

        Raises:
            KeyError: If a name is already registered and ``force`` is
                false.
        """
        ...
        if module_name is None:
            module_name = module_class.__name__
        # Normalize a single name to a list so both forms share one loop.
        if isinstance(module_name, str):
            module_name = [module_name]
        for name in module_name:
            if not force and name in self._module_dict:
                raise KeyError(f'{name} is already registered '
                               f'in {self.name}')
            self._module_dict[name] = module_class

    ...

    def register_module(self, name=None, force=False, module=None):
        """Return a decorator that registers a class and returns it.

        Args:
            name: Forwarded to ``_register_module`` as ``module_name``.
            force: Forwarded to ``_register_module``.
            module: Not used in the part of the method shown here.

        Returns:
            A function that takes a class, registers it, and returns the
            same class unchanged.
        """
        ...

        # use it as a decorator: @x.register_module()
        def _register(cls):
            self._register_module(
                module_class=cls, module_name=name, force=force)
            return cls

        return _register
3. build
从源码可以看出build的实现可能来自于三处,按优先级排序为:
- 构造函数参数 build_func
- 构造函数参数 parent的parent.build_func
- 默认函数 build_from_cfg
class Registry:
    """Registry excerpt showing how ``build_func`` is chosen and used.

    ``...`` marks code that is left out of this excerpt.
    """

    def __init__(self, name, build_func=None, parent=None, scope=None):
        """Set up the lookup tables and pick the build function.

        Args:
            name: Stored as ``self._name``.
            build_func: Callable used by ``build``; takes precedence when
                it is not None.
            parent: Optional parent ``Registry``. Its ``build_func`` is
                reused when ``build_func`` is None, and this registry is
                passed to ``parent._add_children``.
            scope: Stored as ``self._scope``; when None the value comes
                from ``self.infer_scope()``.
        """
        self._name = name
        self._module_dict = dict()
        self._children = dict()
        self._scope = self.infer_scope() if scope is None else scope
        # self.build_func will be set with the following priority:
        # 1. build_func
        # 2. parent.build_func
        # 3. build_from_cfg
        if build_func is None:
            if parent is not None:
                self.build_func = parent.build_func
            else:
                self.build_func = build_from_cfg
        else:
            self.build_func = build_func
        if parent is not None:
            assert isinstance(parent, Registry)
            parent._add_children(self)
            self.parent = parent
        else:
            self.parent = None

    ...

    def get(self, key):
        """Return the class registered under ``key``.

        ``key`` is split into a scope and a real key. When the scope is
        None or equals this registry's scope, the real key is looked up in
        ``self._module_dict``. The other branch is left out of this
        excerpt.
        """
        scope, real_key = self.split_scope_key(key)
        if scope is None or scope == self._scope:
            # get from self
            if real_key in self._module_dict:
                return self._module_dict[real_key]
        else:
            ...

    def build(self, *args, **kwargs):
        """Call ``self.build_func`` with the given arguments.

        This registry is passed along as the ``registry`` keyword
        argument.
        """
        return self.build_func(*args, **kwargs, registry=self)
其中build_from_cfg与Registry定义在同一个文件中,可以看到它就是根据cfg中的type字段,通过registry.get在self._module_dict中查到对应的class,然后以其余参数args调用该class来创建对象。
# mmcv/utils/registry.py
def build_from_cfg(cfg, registry, default_args=None):
    """Instantiate the class named by ``cfg['type']`` (excerpt).

    ``...`` marks code that is left out of this excerpt; in particular,
    only a string ``type`` is handled in the part shown.

    Args:
        cfg: Config mapping that contains the key ``'type'``. A copy is
            taken with ``cfg.copy()`` before ``'type'`` is popped.
        registry: Object whose ``get(name)`` returns the registered class.
        default_args: Optional mapping of fallback keyword arguments. A
            value is used only for keys that ``cfg`` does not set.

    Returns:
        The result of calling the looked-up class with the remaining
        config entries as keyword arguments.

    Raises:
        Exception: An error raised by the constructor is re-raised as the
            same exception type, with the class name prefixed to the
            message.
    """
    ...
    args = cfg.copy()
    if default_args is not None:
        for name, value in default_args.items():
            # setdefault keeps values that cfg already provides.
            args.setdefault(name, value)
    obj_type = args.pop('type')
    if isinstance(obj_type, str):
        obj_cls = registry.get(obj_type)
        ...
    try:
        return obj_cls(**args)
    except Exception as e:
        # Normal TypeError does not print class name.
        raise type(e)(f'{obj_cls.__name__}: {e}')
4. 一个完整的例子
from ..builder import HEADS
# Register the class in the HEADS registry via the decorator.
# ``...`` stands for the class body, which this snippet leaves out.
@HEADS.register_module()
class RepPointLocHead(AnchorFreeHead):
    ...
# Build the detector from cfg.model, passing the 'train_cfg' and
# 'test_cfg' entries read from cfg, then initialize its weights.
model = build_detector(
cfg.model,
train_cfg=cfg.get('train_cfg'),
test_cfg=cfg.get('test_cfg'))
model.init_weights()
def build_detector(cfg, train_cfg=None, test_cfg=None):
    """Build a detector from ``cfg`` through the ``DETECTORS`` registry.

    ``train_cfg`` and ``test_cfg`` are forwarded to ``DETECTORS.build``
    as ``default_args``. ``...`` marks code left out of this excerpt.
    """
    ...
    return DETECTORS.build(
        cfg, default_args=dict(train_cfg=train_cfg, test_cfg=test_cfg))
from mmcv.cnn import MODELS as MMCV_MODELS
from mmcv.utils import Registry
# A single registry named 'models', created with MMCV_MODELS as parent.
MODELS = Registry('models', parent=MMCV_MODELS)

# Every category name below is an alias of that same registry object.
BACKBONES = NECKS = ROI_EXTRACTORS = MODELS
SHARED_HEADS = HEADS = LOSSES = DETECTORS = MODELS
上面的MMCV_MODELS来自于mmcv/cnn/builder.py中的MODELS
def build_model_from_cfg(cfg, registry, default_args=None):
    """Build one module, or a ``Sequential`` of modules, from config.

    Args:
        cfg: A single config, or a list of configs.
        registry: Passed through to ``build_from_cfg``.
        default_args: Passed through to ``build_from_cfg``.

    Returns:
        ``Sequential(*modules)`` when ``cfg`` is a list, with one module
        built per entry; otherwise the result of ``build_from_cfg``.
    """
    if isinstance(cfg, list):
        modules = [
            build_from_cfg(cfg_, registry, default_args) for cfg_ in cfg
        ]
        return Sequential(*modules)
    else:
        return build_from_cfg(cfg, registry, default_args)
# Registry named 'model' that is given build_model_from_cfg as its build_func.
MODELS = Registry('model', build_func=build_model_from_cfg)