This post collects my study notes on Registry(), a very useful utility class in maskrcnn_benchmark.
1. Implementation of Registry()
Registry() is defined in {ROOT_DIR}/maskrcnn_benchmark/utils/registry.py:
```python
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.


def _register_generic(module_dict, module_name, module):
    assert module_name not in module_dict
    module_dict[module_name] = module


class Registry(dict):
    '''
    A helper class for managing registering modules, it extends a dictionary
    and provides a register function.

    Eg. creating a registry:
        some_registry = Registry({"default": default_module})

    There're two ways of registering new modules:
    1): normal way is just calling register function:
        def foo():
            ...
        some_registry.register("foo_module", foo)
    2): used as decorator when declaring the module:
        @some_registry.register("foo_module")
        @some_registry.register("foo_module_nickname")
        def foo():
            ...

    Access of module is just like using a dictionary, eg:
        f = some_registry["foo_module"]
    '''
    def __init__(self, *args, **kwargs):
        super(Registry, self).__init__(*args, **kwargs)

    def register(self, module_name, module=None):
        # used as function call
        if module is not None:
            _register_generic(self, module_name, module)
            return

        # used as decorator
        def register_fn(fn):
            _register_generic(self, module_name, fn)
            return fn

        return register_fn
```
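_register_generic guards against duplicate names with a plain assert. The sketch below copies the class (without its docstring) so it runs without maskrcnn_benchmark installed. It shows that registering the same name twice raises AssertionError, while ordinary item assignment bypasses the check because Registry does not override `__setitem__`.

```python
def _register_generic(module_dict, module_name, module):
    assert module_name not in module_dict
    module_dict[module_name] = module


class Registry(dict):
    # Copy of the Registry class above, docstring omitted
    def register(self, module_name, module=None):
        if module is not None:
            _register_generic(self, module_name, module)
            return

        def register_fn(fn):
            _register_generic(self, module_name, fn)
            return fn

        return register_fn


reg = Registry()
reg.register("foo", 1)

# Registering the same name again trips the assert in _register_generic
try:
    reg.register("foo", 2)
    duplicate_rejected = False
except AssertionError:
    duplicate_rejected = True

# Plain dict assignment is not checked, so it silently overwrites the entry
reg["foo"] = 3

print(duplicate_rejected)  # True
print(reg["foo"])          # 3
```

Because the guard is an assert, it disappears when Python runs with `-O`. In that case a duplicate registration would silently overwrite the earlier entry.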
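Registry() inherits from Python's built-in dict type, so a Registry() instance is essentially a dictionary. On top of dict it adds a single instance method, register(self, module_name, module=None). This method is simply a way of adding a key-value pair to the dictionary, and it can be used in two ways: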
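- Calling the method directly
In this case module must not be None. module_name becomes the key in the dict and module becomes the value. The call itself returns None.
- Using it as a decorator
In this case module must be None (simply omit it). module_name becomes the key, and the decorated function or class object becomes the value. register returns the inner function register_fn, which stores the object and hands it back unchanged, so the decorated name still refers to the original function or class.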
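Both modes can be checked side by side with a minimal sketch. The Registry class is copied from above, and the names foo, bar and baz are hypothetical. The third registration spells out what the `@` syntax does behind the scenes: it calls `register(name)` to get `register_fn`, then applies it to the function.

```python
def _register_generic(module_dict, module_name, module):
    assert module_name not in module_dict
    module_dict[module_name] = module


class Registry(dict):
    # Copy of the Registry class above, docstring omitted
    def register(self, module_name, module=None):
        if module is not None:
            _register_generic(self, module_name, module)
            return

        def register_fn(fn):
            _register_generic(self, module_name, fn)
            return fn

        return register_fn


reg = Registry()


def foo():
    return "foo"


# Mode 1: direct call; module is given and register() returns None
result = reg.register("foo_module", foo)


# Mode 2: decorator; module is omitted and register() returns register_fn
@reg.register("bar_module")
def bar():
    return "bar"


# The decorator syntax above is equivalent to this explicit form
def baz():
    return "baz"


baz = reg.register("baz_module")(baz)

print(result)                    # None
print(reg["bar_module"] is bar)  # True: register_fn returns fn unchanged
print(sorted(reg))               # ['bar_module', 'baz_module', 'foo_module']
```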
In maskrcnn_benchmark, Registry() is mainly used to manage classes and functions.
2. Using the register method of Registry()
```
In[2]: from maskrcnn_benchmark.utils.registry import Registry
In[3]: TEST_REGISTRY = Registry()
In[4]: TEST_REGISTRY
Out[4]: {}
```
2.1 Calling the method directly
```
In[5]: TEST_REGISTRY.register('1', 1)
In[6]: TEST_REGISTRY
Out[6]: {'1': 1}
In[7]: def func1(flag):
  ...:     if flag == 0:
  ...:         return
  ...:     if flag == 1:
  ...:         print('calling func1')
  ...:
In[8]: TEST_REGISTRY.register('func1', func1)
In[9]: TEST_REGISTRY['func1']
Out[9]: <function __main__.func1(flag)>
In[10]: TEST_REGISTRY['func1'](0)
In[11]: TEST_REGISTRY['func1'](1)
calling func1
In[12]: TEST_REGISTRY
Out[12]: {'1': 1, 'func1': <function __main__.func1(flag)>}
```
2.2 Using it as a decorator
```
In[13]: @TEST_REGISTRY.register('func2')
   ...: def func2(flag):
   ...:     if flag == 0:
   ...:         return
   ...:     if flag == 1:
   ...:         print('calling func2')
   ...:
In[14]: TEST_REGISTRY['func2']
Out[14]: <function __main__.func2(flag)>
In[15]: TEST_REGISTRY['func2'](0)
In[16]: TEST_REGISTRY['func2'](1)
calling func2
In[17]: @TEST_REGISTRY.register('Class1')
   ...: class Class1(object):
   ...:     def __init__(self):
   ...:         print('calling Class1')
   ...:
In[18]: TEST_REGISTRY['Class1']
Out[18]: __main__.Class1
In[19]: TEST_REGISTRY['Class1']()
calling Class1
Out[19]: <__main__.Class1 at 0x7f6bfda056d8>
In[20]: instance1 = TEST_REGISTRY['Class1']()
calling Class1
In[21]: type(instance1)
Out[21]: __main__.Class1
```
Stacking decorators: several keys map to the same value
```
In[23]: @TEST_REGISTRY.register('func3-1')
   ...: @TEST_REGISTRY.register('func3-2')
   ...: @TEST_REGISTRY.register('func3-3')
   ...: def func3():
   ...:     print('calling func3')
   ...:
In[24]: TEST_REGISTRY['func3-1']()
calling func3
In[25]: TEST_REGISTRY['func3-2']()
calling func3
In[26]: TEST_REGISTRY['func3-3']()
calling func3
```
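Two details of stacked registration are easy to miss. First, decorators are applied bottom-up, so the innermost name ('func3-3') is inserted first. Second, every key refers to the very same function object, not a copy. A minimal sketch, again with the Registry class copied from above, that checks both (dicts keep insertion order since Python 3.7):

```python
def _register_generic(module_dict, module_name, module):
    assert module_name not in module_dict
    module_dict[module_name] = module


class Registry(dict):
    # Copy of the Registry class above, docstring omitted
    def register(self, module_name, module=None):
        if module is not None:
            _register_generic(self, module_name, module)
            return

        def register_fn(fn):
            _register_generic(self, module_name, fn)
            return fn

        return register_fn


reg = Registry()


# The decorator expressions are evaluated top-down, but applied bottom-up
@reg.register("func3-1")
@reg.register("func3-2")
@reg.register("func3-3")
def func3():
    return "calling func3"


print(list(reg))  # ['func3-3', 'func3-2', 'func3-1']
print(reg["func3-1"] is reg["func3-3"] is func3)  # True
```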
3. How Registry() is used in maskrcnn_benchmark
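In maskrcnn_benchmark, Registry() mainly supports the yacs configuration system by helping manage the model's components: a name written in the config file is mapped to the function or class that builds the corresponding component.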
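The config-driven lookup can be sketched as follows. To keep it self-contained, a plain dict stands in for the yacs config. The key "conv_body", the builder bodies and build_backbone are hypothetical illustrations rather than the library's actual code. The point is that switching components only means changing a string in the config; the build function itself never changes.

```python
def _register_generic(module_dict, module_name, module):
    assert module_name not in module_dict
    module_dict[module_name] = module


class Registry(dict):
    # Copy of the Registry class from section 1, docstring omitted
    def register(self, module_name, module=None):
        if module is not None:
            _register_generic(self, module_name, module)
            return

        def register_fn(fn):
            _register_generic(self, module_name, fn)
            return fn

        return register_fn


BACKBONES = Registry()


@BACKBONES.register("R-50-C4")
def build_resnet_backbone(cfg):
    # Hypothetical builder; a real one would return a torch.nn.Module
    return "resnet backbone, depth " + str(cfg["depth"])


@BACKBONES.register("FBNet")
def build_fbnet_backbone(cfg):
    return "fbnet backbone"


def build_backbone(cfg):
    # Hypothetical: pick the builder by the name stored in the config
    name = cfg["conv_body"]
    if name not in BACKBONES:
        raise KeyError("backbone {!r} is not registered; choices: {}".format(
            name, sorted(BACKBONES)))
    return BACKBONES[name](cfg)


print(build_backbone({"conv_body": "R-50-C4", "depth": 50}))
# resnet backbone, depth 50
```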
First, the registry instances are created in {ROOT_DIR}/maskrcnn_benchmark/modeling/registry.py:
```python
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
from maskrcnn_benchmark.utils.registry import Registry

BACKBONES = Registry()
RPN_HEADS = Registry()
ROI_BOX_FEATURE_EXTRACTORS = Registry()
ROI_BOX_PREDICTOR = Registry()
ROI_KEYPOINT_FEATURE_EXTRACTORS = Registry()
ROI_KEYPOINT_PREDICTOR = Registry()
ROI_MASK_FEATURE_EXTRACTORS = Registry()
ROI_MASK_PREDICTOR = Registry()
```
The register(self, module_name, module=None) method is then used as a decorator in each of the following files:
{ROOT_DIR}/maskrcnn_benchmark/modeling/backbone/backbone.py
{ROOT_DIR}/maskrcnn_benchmark/modeling/rpn/rpn.py
{ROOT_DIR}/maskrcnn_benchmark/modeling/roi_heads/box_head/roi_box_feature_extractors.py
{ROOT_DIR}/maskrcnn_benchmark/modeling/roi_heads/box_head/roi_box_predictors.py
{ROOT_DIR}/maskrcnn_benchmark/modeling/roi_heads/keypoint_head/roi_keypoint_feature_extractors.py
{ROOT_DIR}/maskrcnn_benchmark/modeling/roi_heads/keypoint_head/roi_keypoint_predictors.py
{ROOT_DIR}/maskrcnn_benchmark/modeling/roi_heads/mask_head/roi_mask_feature_extractors.py
{ROOT_DIR}/maskrcnn_benchmark/modeling/roi_heads/mask_head/roi_mask_predictors.py
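Some of these files register classes rather than functions. In that case the registry hands back the class itself, and calling it constructs an instance. A minimal sketch of this pattern is below. The registry name RPN_HEADS and the key 'SingleConvRPNHead' are taken from this post, while the stand-in RPNHead class and its constructor arguments are hypothetical (the real head is a torch.nn.Module).

```python
def _register_generic(module_dict, module_name, module):
    assert module_name not in module_dict
    module_dict[module_name] = module


class Registry(dict):
    # Copy of the Registry class from section 1, docstring omitted
    def register(self, module_name, module=None):
        if module is not None:
            _register_generic(self, module_name, module)
            return

        def register_fn(fn):
            _register_generic(self, module_name, fn)
            return fn

        return register_fn


RPN_HEADS = Registry()


@RPN_HEADS.register("SingleConvRPNHead")
class RPNHead(object):
    # Hypothetical stand-in; arguments are illustrative only
    def __init__(self, cfg, in_channels, num_anchors):
        self.in_channels = in_channels
        self.num_anchors = num_anchors


# The registry returns the class itself; calling it builds an instance
head_cls = RPN_HEADS["SingleConvRPNHead"]
head = head_cls(cfg=None, in_channels=256, num_anchors=3)
print(head_cls is RPNHead, head.in_channels, head.num_anchors)
# True 256 3
```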
Each of these files imports {ROOT_DIR}/maskrcnn_benchmark/modeling/registry.py as a module. Imported on its own, that module only holds empty registries:
```
In[2]: from maskrcnn_benchmark.modeling import registry
In[3]: type(registry)
Out[3]: module
In[4]: registry.BACKBONES
Out[4]: {}
```
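Because register is applied as a decorator, the entries are added when the file that defines them is imported. After importing backbone.py, the same registry module is populated: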
```
In[5]: from maskrcnn_benchmark.modeling.backbone.backbone import registry
In[6]: registry.BACKBONES
Out[6]:
{'R-101-C5': <function maskrcnn_benchmark.modeling.backbone.backbone.build_resnet_backbone(cfg)>,
 'R-101-C4': <function maskrcnn_benchmark.modeling.backbone.backbone.build_resnet_backbone(cfg)>,
 'R-50-C5': <function maskrcnn_benchmark.modeling.backbone.backbone.build_resnet_backbone(cfg)>,
 'R-50-C4': <function maskrcnn_benchmark.modeling.backbone.backbone.build_resnet_backbone(cfg)>,
 'R-152-FPN': <function maskrcnn_benchmark.modeling.backbone.backbone.build_resnet_fpn_backbone(cfg)>,
 'R-101-FPN': <function maskrcnn_benchmark.modeling.backbone.backbone.build_resnet_fpn_backbone(cfg)>,
 'R-50-FPN': <function maskrcnn_benchmark.modeling.backbone.backbone.build_resnet_fpn_backbone(cfg)>,
 'R-101-FPN-RETINANET': <function maskrcnn_benchmark.modeling.backbone.backbone.build_resnet_fpn_p3p7_backbone(cfg)>,
 'R-50-FPN-RETINANET': <function maskrcnn_benchmark.modeling.backbone.backbone.build_resnet_fpn_p3p7_backbone(cfg)>,
 'FBNet': <function maskrcnn_benchmark.modeling.backbone.fbnet.add_conv_body(cfg, dim_in=3)>}
In[7]: registry.RPN_HEADS
Out[7]:
{'SingleConvRPNHead': maskrcnn_benchmark.modeling.rpn.rpn.RPNHead,
 'FBNet.rpn_head': <function maskrcnn_benchmark.modeling.backbone.fbnet.add_rpn_head(cfg, in_channels, num_anchors)>}
```
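Why does importing backbone.py fill a registry that lives in another module? Decorators are ordinary function calls executed while a module body runs, and Python caches modules in `sys.modules`, so every file that imports registry shares the same Registry objects. The sketch below simulates this. The module names demo_registry and demo_backbone and the builder are hypothetical stand-ins, and `exec` of a source string replaces a real file import so the snippet stays self-contained.

```python
import sys
import types


def _register_generic(module_dict, module_name, module):
    assert module_name not in module_dict
    module_dict[module_name] = module


class Registry(dict):
    # Copy of the Registry class from section 1, docstring omitted
    def register(self, module_name, module=None):
        if module is not None:
            _register_generic(self, module_name, module)
            return

        def register_fn(fn):
            _register_generic(self, module_name, fn)
            return fn

        return register_fn


# Hypothetical stand-in for modeling/registry.py: a module holding an empty registry
registry_mod = types.ModuleType("demo_registry")
registry_mod.BACKBONES = Registry()
sys.modules["demo_registry"] = registry_mod

# Hypothetical stand-in for backbone.py; its decorators run when the body executes
BACKBONE_SRC = '''
from demo_registry import BACKBONES

@BACKBONES.register("R-50-C4")
@BACKBONES.register("R-50-C5")
def build_resnet_backbone(cfg):
    return "resnet"
'''

before = sorted(registry_mod.BACKBONES)

# Executing the module body plays the role of "import backbone"
backbone_mod = types.ModuleType("demo_backbone")
exec(BACKBONE_SRC, backbone_mod.__dict__)

after = sorted(registry_mod.BACKBONES)
print(before)  # []
print(after)   # ['R-50-C4', 'R-50-C5']
print(registry_mod.BACKBONES["R-50-C4"] is backbone_mod.build_resnet_backbone)  # True
```

The practical consequence is that a component is only available by name once the file that registers it has been imported somewhere.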