注册器设计模式,以SparseInst(detectron2)代码为例

注册器设计模式

from detectron2.utils.registry import Registry

这里面的Registry实际上就是注册器设计模式

1.举一个简单版的小例子来理解注册器设计模式

参考:Python注册器设计模式_python 注册类-CSDN博客

# 这一行代码是从Python的 typing 模块中导入了一些类型提示(type hint)相关的工具。
# Python的 typing 模块提供了静态类型检查的支持,这对提高代码的可读性和可靠性很有帮助
from typing import Any, Callable, Dict, List, Optional, Type, Union

# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# pyre-ignore-all-errors[2,3]
from typing import Any, Dict, Iterable, Iterator, Tuple

from tabulate import tabulate


class Registry(Iterable[Tuple[str, Any]]):
    """
    The registry that provides name -> object mapping, to support third-party
    users' custom modules.

    To create a registry (e.g. a backbone registry):

    .. code-block:: python

        BACKBONE_REGISTRY = Registry('BACKBONE')

    To register an object:

    .. code-block:: python

        @BACKBONE_REGISTRY.register()
        class MyBackbone():
            ...

    Or:

    .. code-block:: python

        BACKBONE_REGISTRY.register(MyBackbone)
    """

    def __init__(self, name: str) -> None:
        """
        Args:
            name (str): the name of this registry
        """
        self._name: str = name
        self._obj_map: Dict[str, Any] = {}

    def _do_register(self, name: str, obj: Any) -> None:
        assert (
            name not in self._obj_map
        ), "An object named '{}' was already registered in '{}' registry!".format(
            name, self._name
        )
        self._obj_map[name] = obj

    def register(self, obj: Any = None) -> Any:
        """
        Register the given object under the the name `obj.__name__`.
        Can be used as either a decorator or not. See docstring of this class for usage.
        """
        if obj is None:
            # used as a decorator
            def deco(func_or_class: Any) -> Any:
                name = func_or_class.__name__
                self._do_register(name, func_or_class)
                return func_or_class

            return deco

        # used as a function call
        name = obj.__name__
        self._do_register(name, obj)

    def get(self, name: str) -> Any:
        ret = self._obj_map.get(name)
        if ret is None:
            raise KeyError(
                "No object named '{}' found in '{}' registry!".format(name, self._name)
            )
        return ret

    def __contains__(self, name: str) -> bool:
        return name in self._obj_map

    def __repr__(self) -> str:
        table_headers = ["Names", "Objects"]
        table = tabulate(
            self._obj_map.items(), headers=table_headers, tablefmt="fancy_grid"
        )
        return "Registry of {}:\n".format(self._name) + table

    def __iter__(self) -> Iterator[Tuple[str, Any]]:
        return iter(self._obj_map.items())

    def build(self, cfg: Dict[str, Any]) -> Any:
        if not isinstance(cfg, dict) or 'type' not in cfg:#检查字典是否符合规定,规定中,键值要写为 type
            raise TypeError('cfg must be a dict and contain the key "type"')
        module_type = cfg.pop('type')#获取模块类型
        module_cls = self.get(module_type)#获取模块类
        if module_cls is None:#如果模块不存在,报错
            raise KeyError(f'{module_type} is not registered in {self._name}')
        return module_cls(**cfg)#实例化模块并且返回
    # pyre-fixme[4]: Attribute must be annotated.
    __str__ = __repr__





# 用法示例
if __name__ == '__main__':
    # 第二步、定义一个注册表
    MODELS = Registry("SPARSE_INST_ENCODER")


    # 第三步、使用装饰器在注册表注册模块,比如一个类
    @MODELS.register()
    class ResNet:
        def __init__(self, depth):
            self.depth = depth

    # 使用装饰器注册模块,比如一个函数
    @MODELS.register()
    def resnet50():
        return ResNet(depth=50)

    # 使用普通函数注册模块,比如一个类,先声明一个类
    class MobileNet:
        def __init__(self, width_multiplier):
            self.width_multiplier = width_multiplier
    #注册到注册表中
    MODELS.register(MobileNet)#或者直接在MobileNet类声明上面用装饰器@MODELS.register_module()


    # 第四步、使用build函数,构建模型(其实是各个已注册模块的实例化)
    resnet = MODELS.build(dict(type='ResNet', depth=18))
    print(f'ResNet: depth = {resnet.depth}')  # 输出: ResNet: depth = 18

    mobilenet = MODELS.build(dict(type='MobileNet', width_multiplier=1.0))
    print(f'MobileNet: width_multiplier = {mobilenet.width_multiplier}')  # 输出: MobileNet: width_multiplier = 1.0

    resnet_instance = MODELS.build(dict(type='resnet50'))
    print(
        f'ResNet instance from function: depth = {resnet_instance.depth}')  # 输出: ResNet instance from function: depth = 50

注册器模式实现的四个步骤:

1、定义注册器类


注册器类比较关键,需要实现了好几个功能,各种模块的
注册:内部函数_do_register负责具体注册的实现;外部函数register暴漏给编码人员,写代码的时候用

储存:将被注册的模块(类、函数、等)存在注册器类的字典中。所以一般__init__() 里会初始化一个字典

获取:使用函数get,获取已注册对象,传入类的名称,返回这个类的实际实现的引用

实例化:创建build函数,实例化被注册的模块

2、初始化注册器

注册器又称注册表,创建一个注册器:用注册器类新建一个对象。
MODELS = Registry(‘models’)

3、注册可调用对象

将模块(类、函数、等)注册到注册器中去

可以用隐式的方法:装饰器挂在类的声明实现的头上,就可以完成注册了

@MODELS.register_module()
class ResNet:

如以上代码,能把ResNet类注册到注册表中。

也可以使用显示的注册方式,如

    # 使用普通函数注册模块,比如一个类,先声明一个类
    class MobileNet:
        def __init__(self, width_multiplier):
            self.width_multiplier = width_multiplier
    #注册到注册表中
    MODELS.register(MobileNet)#或者直接在MobileNet类声明上面用装饰器@MODELS.register_module()

这两种写法都是可以的,底层实现是一样的。

4、使用注册器构建对象

也就是已经被注册到注册表中的模块(类、函数、等)的实例化,实例化用的是build函数
比如可以这么写:

resnet_cfg = {'type': 'ResNet', 'depth': 18}
resnet = MODELS.build(resnet_cfg)

也可

resnet = MODELS.build(dict(type='ResNet', depth=18))

2.detectron2里的注册器设计模式

detectron2第一步定义注册器类的代码:

# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# pyre-ignore-all-errors[2,3]
from typing import Any, Dict, Iterable, Iterator, Tuple

from tabulate import tabulate


class Registry(Iterable[Tuple[str, Any]]):
    """
    The registry that provides name -> object mapping, to support third-party
    users' custom modules.

    To create a registry (e.g. a backbone registry):

    .. code-block:: python

        BACKBONE_REGISTRY = Registry('BACKBONE')

    To register an object:

    .. code-block:: python

        @BACKBONE_REGISTRY.register()
        class MyBackbone():
            ...

    Or:

    .. code-block:: python

        BACKBONE_REGISTRY.register(MyBackbone)
    """

    def __init__(self, name: str) -> None:
        """
        Args:
            name (str): the name of this registry
        """
        self._name: str = name
        self._obj_map: Dict[str, Any] = {}

    def _do_register(self, name: str, obj: Any) -> None:
        assert (
            name not in self._obj_map
        ), "An object named '{}' was already registered in '{}' registry!".format(
            name, self._name
        )
        self._obj_map[name] = obj

    def register(self, obj: Any = None) -> Any:
        """
        Register the given object under the the name `obj.__name__`.
        Can be used as either a decorator or not. See docstring of this class for usage.
        """
        if obj is None:
            # used as a decorator
            def deco(func_or_class: Any) -> Any:
                name = func_or_class.__name__
                self._do_register(name, func_or_class)
                return func_or_class

            return deco

        # used as a function call
        name = obj.__name__
        self._do_register(name, obj)

    def get(self, name: str) -> Any:
        ret = self._obj_map.get(name)
        if ret is None:
            raise KeyError(
                "No object named '{}' found in '{}' registry!".format(name, self._name)
            )
        return ret

    def __contains__(self, name: str) -> bool:
        return name in self._obj_map

    def __repr__(self) -> str:
        table_headers = ["Names", "Objects"]
        table = tabulate(
            self._obj_map.items(), headers=table_headers, tablefmt="fancy_grid"
        )
        return "Registry of {}:\n".format(self._name) + table

    def __iter__(self) -> Iterator[Tuple[str, Any]]:
        return iter(self._obj_map.items())

    # pyre-fixme[4]: Attribute must be annotated.
    __str__ = __repr__

第二步初始化注册器的代码:

以初始化BACKBONE_REGISTRY注册器为例 

第三步注册可调用对象: 

第四步使用注册器构建对象:

实际上构建对象就只有一句话

 model = META_ARCH_REGISTRY.get(meta_arch)(cfg)

传入参数meta_arch的值为字字符串'SparseInst'。所以META_ARCH_REGISTRY.get(meta_arch)得到了SparseInst类名,META_ARCH_REGISTRY.get(meta_arch)(cfg)实际上就是SparseInst(cfg),也就是上面这句代码跟这句代码等价

model=SparseInst(cfg)

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值