class Registry:
"""A registry to map strings to classes.
Registered object could be built from registry.
Example:
>>> MODELS = Registry('models')
>>> @MODELS.register_module()
>>> class ResNet:
>>> pass
>>> resnet = MODELS.build(dict(type='ResNet'))
Please refer to
https://mmcv.readthedocs.io/en/latest/understand_mmcv/registry.html for
advanced usage.
Args:
name (str): Registry name.
build_func(func, optional): Build function to construct instance from
Registry, func:`build_from_cfg` is used if neither ``parent`` or
``build_func`` is specified. If ``parent`` is specified and
``build_func`` is not given, ``build_func`` will be inherited
from ``parent``. Default: None.
parent (Registry, optional): Parent registry. The class registered in
children registry could be built from parent. Default: None.
scope (str, optional): The scope of registry. It is the key to search
for children registry. If not specified, scope will be the name of
the package where class is defined, e.g. mmdet, mmcls, mmseg.
Default: None.
"""
def __init__(self, name, build_func=None, parent=None, scope=None):
self._name = name
self._module_dict = dict()
self._children = dict()
self._scope = self.infer_scope() if scope is None else scope
# self.build_func will be set with the following priority:
# 1. build_func
# 2. parent.build_func
# 3. build_from_cfg
if build_func is None:
if parent is not None:
self.build_func = parent.build_func
else:
self.build_func = build_from_cfg
else:
self.build_func = build_func
if parent is not None:
assert isinstance(parent, Registry)
parent._add_children(self)
self.parent = parent
else:
self.parent = None
def __len__(self):
return len(self._module_dict)
def __contains__(self, key):
return self.get(key) is not None
def __repr__(self):
format_str = self.__class__.__name__ + \
f'(name={self._name}, ' \
f'items={self._module_dict})'
return format_str
@staticmethod
def infer_scope():
"""Infer the scope of registry.
The name of the package where registry is defined will be returned.
Example:
# in mmdet/models/backbone/resnet.py
>>> MODELS = Registry('models')
>>> @MODELS.register_module()
>>> class ResNet:
>>> pass
The scope of ``ResNet`` will be ``mmdet``.
Returns:
scope (str): The inferred scope name.
"""
# inspect.stack() trace where this function is called, the index-2
# indicates the frame where `infer_scope()` is called
filename = inspect.getmodule(inspect.stack()[2][0]).__name__
split_filename = filename.split('.')
return split_filename[0]
@staticmethod
def split_scope_key(key):
"""Split scope and key.
The first scope will be split from key.
Examples:
>>> Registry.split_scope_key('mmdet.ResNet')
'mmdet', 'ResNet'
>>> Registry.split_scope_key('ResNet')
None, 'ResNet'
Return:
scope (str, None): The first scope.
key (str): The remaining key.
"""
split_index = key.find('.')
if split_index != -1:
return key[:split_index], key[split_index + 1:]
else:
return None, key
@property
def name(self):
return self._name
@property
def scope(self):
return self._scope
@property
def module_dict(self):
return self._module_dict
@property
def children(self):
return self._children
def get(self, key):
"""Get the registry record.
Args:
key (str): The class name in string format.
Returns:
class: The corresponding class.
"""
scope, real_key = self.split_scope_key(key)
if scope is None or scope == self._scope:
# get from self
if real_key in self._module_dict:
return self._module_dict[real_key]
else:
# get from self._children
if scope in self._children:
return self._children[scope].get(real_key)
else:
# goto root
parent = self.parent
while parent.parent is not None:
parent = parent.parent
return parent.get(key)
def build(self, *args, **kwargs):
return self.build_func(*args, **kwargs, registry=self)
def _add_children(self, registry):
"""Add children for a registry.
The ``registry`` will be added as children based on its scope.
The parent registry could build objects from children registry.
Example:
>>> models = Registry('models')
>>> mmdet_models = Registry('models', parent=models)
>>> @mmdet_models.register_module()
>>> class ResNet:
>>> pass
>>> resnet = models.build(dict(type='mmdet.ResNet'))
"""
assert isinstance(registry, Registry)
assert registry.scope is not None
assert registry.scope not in self.children, \
f'scope {registry.scope} exists in {self.name} registry'
self.children[registry.scope] = registry
def _register_module(self, module_class, module_name=None, force=False):
if not inspect.isclass(module_class):
raise TypeError('module must be a class, '
f'but got {type(module_class)}')
if module_name is None:
module_name = module_class.__name__
if isinstance(module_name, str):
module_name = [module_name]
for name in module_name:
if not force and name in self._module_dict:
raise KeyError(f'{name} is already registered '
f'in {self.name}')
self._module_dict[name] = module_class
def deprecated_register_module(self, cls=None, force=False):
warnings.warn(
'The old API of register_module(module, force=False) '
'is deprecated and will be removed, please use the new API '
'register_module(name=None, force=False, module=None) instead.')
if cls is None:
return partial(self.deprecated_register_module, force=force)
self._register_module(cls, force=force)
return cls
def register_module(self, name=None, force=False, module=None):
"""Register a module.
A record will be added to `self._module_dict`, whose key is the class
name or the specified name, and value is the class itself.
It can be used as a decorator or a normal function.
Example:
>>> backbones = Registry('backbone')
>>> @backbones.register_module()
>>> class ResNet:
>>> pass
>>> backbones = Registry('backbone')
>>> @backbones.register_module(name='mnet')
>>> class MobileNet:
>>> pass
>>> backbones = Registry('backbone')
>>> class ResNet:
>>> pass
>>> backbones.register_module(ResNet)
Args:
name (str | None): The module name to be registered. If not
specified, the class name will be used.
force (bool, optional): Whether to override an existing class with
the same name. Default: False.
module (type): Module class to be registered.
"""
if not isinstance(force, bool):
raise TypeError(f'force must be a boolean, but got {type(force)}')
# NOTE: This is a walkaround to be compatible with the old api,
# while it may introduce unexpected bugs.
if isinstance(name, type):
return self.deprecated_register_module(name, force=force)
# raise the error ahead of time
if not (name is None or isinstance(name, str) or is_seq_of(name, str)):
raise TypeError(
'name must be either of None, an instance of str or a sequence'
f' of str, but got {type(name)}')
# use it as a normal method: x.register_module(module=SomeClass)
if module is not None:
self._register_module(
module_class=module, module_name=name, force=force)
return module
# use it as a decorator: @x.register_module()
def _register(cls):
self._register_module(
module_class=cls, module_name=name, force=force)
return cls
return _register
_register_module
方法
def _register_module(self, module_class, module_name=None, force=False):
if not inspect.isclass(module_class):
raise TypeError('module must be a class, '
f'but got {type(module_class)}')
if module_name is None:
module_name = module_class.__name__
if isinstance(module_name, str):
module_name = [module_name]
for name in module_name:
if not force and name in self._module_dict:
raise KeyError(f'{name} is already registered '
f'in {self.name}')
self._module_dict[name] = module_class
作用:
_register_module
是 Registry
类的一个私有方法,负责将类对象与类名(或指定的名字)映射到注册表中。这是类注册过程的核心逻辑。
参数说明:
-
module_class
:要注册的类对象。它必须是一个 Python 类。 -
module_name
:类在注册表中的名字,可以是字符串或字符串列表。如果未指定,默认使用类的名字。 -
force
:布尔值,指示是否允许覆盖已存在的同名类。如果为False
,且_module_dict
中已存在同名类,方法将抛出KeyError
。
逻辑详解:
-
类型检查: 使用
inspect.isclass
检查module_class
是否是一个类。如果不是,抛出TypeError
,并提示错误信息。 -
确定模块名称: 如果
module_name
未提供(即为None
),则使用类的名字(module_class.__name__
)作为默认名称。 如果module_name
是一个字符串,它会被转换为一个包含单个元素的列表,这样方便后续处理多个名字的情况。 -
注册过程: 遍历
module_name
列表中的每个名称:- 如果
force
为False
,且_module_dict
中已存在该名字,抛出KeyError
,提示名字已存在。 - 如果
force
为True
,或者名字不存在于_module_dict
中,则将module_class
注册到_module_dict
中,键为module_name
中的名字,值为module_class
。
- 如果
deprecated_register_module
方法
def deprecated_register_module(self, cls=None, force=False):
warnings.warn(
'The old API of register_module(module, force=False) '
'is deprecated and will be removed, please use the new API '
'register_module(name=None, force=False, module=None) instead.')
if cls is None:
return partial(self.deprecated_register_module, force=force)
self._register_module(cls, force=force)
return cls
作用:
deprecated_register_module
方法提供了对旧的类注册 API 的兼容性支持。它允许用户使用较旧的方式来注册类,并在使用时发出警告,建议迁移到新的 API。
参数说明:
-
cls
:要注册的类对象。如果为None
,方法返回一个部分应用的函数,该函数在调用时会传入类对象。 -
force
:布尔值,指示是否允许覆盖已存在的同名类。
逻辑详解:
-
警告: 方法首先发出一个弃用警告,提示用户旧的 API 将被移除,并建议使用新的
register_module
方法。 -
返回部分应用函数: 如果
cls
为None
,方法使用functools.partial
返回一个部分应用的函数,该函数封装了deprecated_register_module
,但force
参数已经绑定。这样用户可以稍后传入类对象。 -
注册类: 如果
cls
不为None
,方法直接调用_register_module
方法将cls
注册到注册表中。 -
返回类: 注册完成后,方法返回
cls
类对象本身。
register_module
方法
def register_module(self, name=None, force=False, module=None):
if not isinstance(force, bool):
raise TypeError(f'force must be a boolean, but got {type(force)}')
if isinstance(name, type):
return self.deprecated_register_module(name, force=force)
if not (name is None or isinstance(name, str) or is_seq_of(name, str)):
raise TypeError(
'name must be either of None, an instance of str or a sequence'
f' of str, but got {type(name)}')
if module is not None:
self._register_module(
module_class=module, module_name=name, force=force)
return module
def _register(cls):
self._register_module(
module_class=cls, module_name=name, force=force)
return cls
return _register
作用:
register_module
方法是 Registry
类的一个公共方法,用于注册类到注册表中。它可以作为装饰器或普通方法使用,为用户提供了灵活的类注册接口。
参数说明:
-
name
:注册的类的名称或名称列表。可以是None
(使用类的默认名称)、字符串,或字符串的列表。 -
force
:布尔值,指示是否允许覆盖已存在的同名类。如果为False
,且_module_dict
中已存在同名类,方法将抛出KeyError
。 -
module
:需要注册的类对象。如果提供了module
,方法会立即将其注册到注册表中。
逻辑详解:
-
参数验证:
- 首先,检查
force
是否为布尔值。如果不是,抛出TypeError
。 - 接下来,检查
name
是否为一个类对象。如果是,调用旧的注册方法deprecated_register_module
进行注册,以保持兼容性。 - 然后,检查
name
是否为None
、字符串或字符串列表。如果不是,抛出TypeError
。
- 首先,检查
-
直接注册类:
- 如果
module
不为None
,方法将module
和name
传递给_register_module
进行注册。 - 注册完成后,方法返回
module
类对象本身。
- 如果
-
作为装饰器使用:
- 如果
module
为None
,方法返回一个内部的_register
函数,这个函数会作为装饰器使用。 _register
函数接受一个类作为参数,将其与name
一起传递给_register_module
进行注册。- 装饰器的效果是:被装饰的类会被自动注册到注册表中,注册完成后,返回该类对象本身。
- 如果
使用示例:
-
作为装饰器使用:
backbones = Registry('backbone') @backbones.register_module() class ResNet: pass
在这个示例中,
ResNet
类被注册到backbones
注册表中。register_module
方法作为装饰器使用,不需要显式地调用_register_module
。 -
指定名称的装饰器:
backbones = Registry('backbone') @backbones.register_module(name='mnet') class MobileNet: pass
在这个示例中,
MobileNet
类被注册为mnet
,而不是其默认的类名。这样在注册表中可以通过'mnet'
来引用这个类。 -
作为普通方法调用:
backbones = Registry('backbone') class ResNet: pass backbones.register_module(module=ResNet)
这种方式是直接调用
register_module
方法,将ResNet
类注册到backbones
注册表中。
总结
这三个方法共同实现了 Registry
类的核心功能,允许用户将类注册到注册表中,并通过简单的字符串来引用和使用这些类。_register_module
是底层的注册逻辑,deprecated_register_module
提供了对旧的 API 的兼容性支持,而 register_module
是一个高层次的接口,提供了灵活的类注册方式。
(调试BEVformer的一些感想)