mmdetection代码阅读系列(二):Decorators in mmdetection, Registry工厂模式

在mmdetection中大量使用了如下的decorator(其本质是一种工厂模式),本文详细介绍这种模式

from ..builder import HEADS
@HEADS.register_module()
class RepPointLocHead(AnchorFreeHead):
	...

Registry

Registry就是上一篇 3.2中描述的工厂模式的Manager,有两个作用:

  • 注册(register):记录对象名字到对象class的映射
  • 创建(build) :根据配置信息中给出的名字的参数创建对应class的对象

这两个功能分别由register_module和build两个函数来实现。

1. exmaple of Registry

对它的使用例子如下:

MODELS = Registry('models')
@MODELS.register_module()
class ResNet:
    ...
resnet = MODELS.build(dict(type='ResNet'))

2. register_module

上面例子代码的类比代码为

class ResNet:
    ...
ResNet = MODELS.register_module()(ResNet)
resnet = MODELS.build(dict(type='ResNet'))

从源代码可以看出register_module返回一个Decorator,它的输入是一个class,把class.__name__到class的映射记录在self._module_dict中,然后原封不动输出这个class

class Registry:
    def __init__(self, name, build_func=None, parent=None, scope=None):
        self._name = name
        ...
        
    def _register_module(self, module_class, module_name=None, force=False):
    	...
        if module_name is None:
            module_name = module_class.__name__
        if isinstance(module_name, str):
            module_name = [module_name]
        for name in module_name:
            if not force and name in self._module_dict:
                raise KeyError(f'{name} is already registered '
                               f'in {self.name}')
            self._module_dict[name] = module_class
	...
    def register_module(self, name=None, force=False, module=None):
        ...
        # use it as a decorator: @x.register_module()
        def _register(cls):
            self._register_module(
                module_class=cls, module_name=name, force=force)
            return cls
        return _register

3. build

从源码可以看出build的实现可能来自于三处,按优先级排序为:

  • 构造函数参数 build_func
  • 构造函数参数 parent的parent.build_func
  • 默认函数 build_from_cfg
class Registry:
    def __init__(self, name, build_func=None, parent=None, scope=None):
        self._name = name
        self._module_dict = dict()
        self._children = dict()
        self._scope = self.infer_scope() if scope is None else scope

        # self.build_func will be set with the following priority:
        # 1. build_func
        # 2. parent.build_func
        # 3. build_from_cfg
        if build_func is None:
            if parent is not None:
                self.build_func = parent.build_func
            else:
                self.build_func = build_from_cfg
        else:
            self.build_func = build_func
        if parent is not None:
            assert isinstance(parent, Registry)
            parent._add_children(self)
            self.parent = parent
        else:
            self.parent = None
    ...
    def get(self, key):
        scope, real_key = self.split_scope_key(key)
        if scope is None or scope == self._scope:
            # get from self
            if real_key in self._module_dict:
                return self._module_dict[real_key]
        else:
            ...
                
    def build(self, *args, **kwargs):
        return self.build_func(*args, **kwargs, registry=self)

其中build_from_cfg就在Registy同一个文件中定义,可以看到它就是根据cfg中定义的type字段在self._module_dict中映射得到对应的class,然后根据参数args进行调用。

# mmcv/utils/registry.py
def build_from_cfg(cfg, registry, default_args=None):
	...
    args = cfg.copy()

    if default_args is not None:
        for name, value in default_args.items():
            args.setdefault(name, value)

    obj_type = args.pop('type')
    if isinstance(obj_type, str):
        obj_cls = registry.get(obj_type)
        ...
    try:
        return obj_cls(**args)
    except Exception as e:
        # Normal TypeError does not print class name.
        raise type(e)(f'{obj_cls.__name__}: {e}')

4. 一个完整的例子

from ..builder import HEADS

@HEADS.register_module()
class RepPointLocHead(AnchorFreeHead):
model = build_detector(
    cfg.model,
    train_cfg=cfg.get('train_cfg'),
    test_cfg=cfg.get('test_cfg'))
model.init_weights()
def build_detector(cfg, train_cfg=None, test_cfg=None):
    ...
    return DETECTORS.build(
        cfg, default_args=dict(train_cfg=train_cfg, test_cfg=test_cfg))
from mmcv.cnn import MODELS as MMCV_MODELS
from mmcv.utils import Registry

MODELS = Registry('models', parent=MMCV_MODELS)

BACKBONES = MODELS
NECKS = MODELS
ROI_EXTRACTORS = MODELS
SHARED_HEADS = MODELS
HEADS = MODELS
LOSSES = MODELS
DETECTORS = MODELS

上面的MMCV_MODELS来自于mmcv/cnn/builder.py中的MODELS

def build_model_from_cfg(cfg, registry, default_args=None):
    if isinstance(cfg, list):
        modules = [
            build_from_cfg(cfg_, registry, default_args) for cfg_ in cfg
        ]
        return Sequential(*modules)
    else:
        return build_from_cfg(cfg, registry, default_args)


MODELS = Registry('model', build_func=build_model_from_cfg)
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值