BEVformer一些自己的理解


class Registry:
    """A registry to map strings to classes.

    Registered object could be built from registry.
    Example:
        >>> MODELS = Registry('models')
        >>> @MODELS.register_module()
        >>> class ResNet:
        >>>     pass
        >>> resnet = MODELS.build(dict(type='ResNet'))

    Please refer to
    https://mmcv.readthedocs.io/en/latest/understand_mmcv/registry.html for
    advanced usage.

    Args:
        name (str): Registry name.
        build_func(func, optional): Build function to construct instance from
            Registry, func:`build_from_cfg` is used if neither ``parent`` or
            ``build_func`` is specified. If ``parent`` is specified and
            ``build_func`` is not given,  ``build_func`` will be inherited
            from ``parent``. Default: None.
        parent (Registry, optional): Parent registry. The class registered in
            children registry could be built from parent. Default: None.
        scope (str, optional): The scope of registry. It is the key to search
            for children registry. If not specified, scope will be the name of
            the package where class is defined, e.g. mmdet, mmcls, mmseg.
            Default: None.
    """

    def __init__(self, name, build_func=None, parent=None, scope=None):
        self._name = name
        self._module_dict = dict()
        self._children = dict()
        self._scope = self.infer_scope() if scope is None else scope

        # self.build_func will be set with the following priority:
        # 1. build_func
        # 2. parent.build_func
        # 3. build_from_cfg
        if build_func is None:
            if parent is not None:
                self.build_func = parent.build_func
            else:
                self.build_func = build_from_cfg
        else:
            self.build_func = build_func
        if parent is not None:
            assert isinstance(parent, Registry)
            parent._add_children(self)
            self.parent = parent
        else:
            self.parent = None

    def __len__(self):
        return len(self._module_dict)

    def __contains__(self, key):
        return self.get(key) is not None

    def __repr__(self):
        format_str = self.__class__.__name__ + \
                     f'(name={self._name}, ' \
                     f'items={self._module_dict})'
        return format_str

    @staticmethod
    def infer_scope():
        """Infer the scope of registry.

        The name of the package where registry is defined will be returned.

        Example:
            # in mmdet/models/backbone/resnet.py
            >>> MODELS = Registry('models')
            >>> @MODELS.register_module()
            >>> class ResNet:
            >>>     pass
            The scope of ``ResNet`` will be ``mmdet``.


        Returns:
            scope (str): The inferred scope name.
        """
        # inspect.stack() trace where this function is called, the index-2
        # indicates the frame where `infer_scope()` is called
        filename = inspect.getmodule(inspect.stack()[2][0]).__name__
        split_filename = filename.split('.')
        return split_filename[0]

    @staticmethod
    def split_scope_key(key):
        """Split scope and key.

        The first scope will be split from key.

        Examples:
            >>> Registry.split_scope_key('mmdet.ResNet')
            'mmdet', 'ResNet'
            >>> Registry.split_scope_key('ResNet')
            None, 'ResNet'

        Return:
            scope (str, None): The first scope.
            key (str): The remaining key.
        """
        split_index = key.find('.')
        if split_index != -1:
            return key[:split_index], key[split_index + 1:]
        else:
            return None, key

    @property
    def name(self):
        return self._name

    @property
    def scope(self):
        return self._scope

    @property
    def module_dict(self):
        return self._module_dict

    @property
    def children(self):
        return self._children

    def get(self, key):
        """Get the registry record.

        Args:
            key (str): The class name in string format.

        Returns:
            class: The corresponding class.
        """
        scope, real_key = self.split_scope_key(key)
        if scope is None or scope == self._scope:
            # get from self
            if real_key in self._module_dict:
                return self._module_dict[real_key]
        else:
            # get from self._children
            if scope in self._children:
                return self._children[scope].get(real_key)
            else:
                # goto root
                parent = self.parent
                while parent.parent is not None:
                    parent = parent.parent
                return parent.get(key)

    def build(self, *args, **kwargs):
        return self.build_func(*args, **kwargs, registry=self)

    def _add_children(self, registry):
        """Add children for a registry.

        The ``registry`` will be added as children based on its scope.
        The parent registry could build objects from children registry.

        Example:
            >>> models = Registry('models')
            >>> mmdet_models = Registry('models', parent=models)
            >>> @mmdet_models.register_module()
            >>> class ResNet:
            >>>     pass
            >>> resnet = models.build(dict(type='mmdet.ResNet'))
        """

        assert isinstance(registry, Registry)
        assert registry.scope is not None
        assert registry.scope not in self.children, \
            f'scope {registry.scope} exists in {self.name} registry'
        self.children[registry.scope] = registry

    def _register_module(self, module_class, module_name=None, force=False):
        if not inspect.isclass(module_class):
            raise TypeError('module must be a class, '
                            f'but got {type(module_class)}')

        if module_name is None:
            module_name = module_class.__name__
        if isinstance(module_name, str):
            module_name = [module_name]
        for name in module_name:
            if not force and name in self._module_dict:
                raise KeyError(f'{name} is already registered '
                               f'in {self.name}')
            self._module_dict[name] = module_class

    def deprecated_register_module(self, cls=None, force=False):
        warnings.warn(
            'The old API of register_module(module, force=False) '
            'is deprecated and will be removed, please use the new API '
            'register_module(name=None, force=False, module=None) instead.')
        if cls is None:
            return partial(self.deprecated_register_module, force=force)
        self._register_module(cls, force=force)
        return cls

    def register_module(self, name=None, force=False, module=None):
        """Register a module.

        A record will be added to `self._module_dict`, whose key is the class
        name or the specified name, and value is the class itself.
        It can be used as a decorator or a normal function.

        Example:
            >>> backbones = Registry('backbone')
            >>> @backbones.register_module()
            >>> class ResNet:
            >>>     pass

            >>> backbones = Registry('backbone')
            >>> @backbones.register_module(name='mnet')
            >>> class MobileNet:
            >>>     pass

            >>> backbones = Registry('backbone')
            >>> class ResNet:
            >>>     pass
            >>> backbones.register_module(ResNet)

        Args:
            name (str | None): The module name to be registered. If not
                specified, the class name will be used.
            force (bool, optional): Whether to override an existing class with
                the same name. Default: False.
            module (type): Module class to be registered.
        """
        if not isinstance(force, bool):
            raise TypeError(f'force must be a boolean, but got {type(force)}')
        # NOTE: This is a walkaround to be compatible with the old api,
        # while it may introduce unexpected bugs.
        if isinstance(name, type):
            return self.deprecated_register_module(name, force=force)

        # raise the error ahead of time
        if not (name is None or isinstance(name, str) or is_seq_of(name, str)):
            raise TypeError(
                'name must be either of None, an instance of str or a sequence'
                f'  of str, but got {type(name)}')

        # use it as a normal method: x.register_module(module=SomeClass)
        if module is not None:
            self._register_module(
                module_class=module, module_name=name, force=force)
            return module

        # use it as a decorator: @x.register_module()
        def _register(cls):
            self._register_module(
                module_class=cls, module_name=name, force=force)
            return cls

        return _register

_register_module 方法

def _register_module(self, module_class, module_name=None, force=False):
    if not inspect.isclass(module_class):
        raise TypeError('module must be a class, '
                        f'but got {type(module_class)}')

    if module_name is None:
        module_name = module_class.__name__
    if isinstance(module_name, str):
        module_name = [module_name]
    for name in module_name:
        if not force and name in self._module_dict:
            raise KeyError(f'{name} is already registered '
                           f'in {self.name}')
        self._module_dict[name] = module_class
作用:

_register_moduleRegistry 类的一个私有方法,负责将类对象与类名(或指定的名字)映射到注册表中。这是类注册过程的核心逻辑。

参数说明:
  • module_class:要注册的类对象。它必须是一个 Python 类。

  • module_name:类在注册表中的名字,可以是字符串或字符串列表。如果未指定,默认使用类的名字。

  • force:布尔值,指示是否允许覆盖已存在的同名类。如果为 False,且 _module_dict 中已存在同名类,方法将抛出 KeyError

逻辑详解:
  1. 类型检查: 使用 inspect.isclass 检查 module_class 是否是一个类。如果不是,抛出 TypeError,并提示错误信息。

  2. 确定模块名称: 如果 module_name 未提供(即为 None),则使用类的名字(module_class.__name__)作为默认名称。 如果 module_name 是一个字符串,它会被转换为一个包含单个元素的列表,这样方便后续处理多个名字的情况。

  3. 注册过程: 遍历 module_name 列表中的每个名称:

    • 如果 forceFalse,且 _module_dict 中已存在该名字,抛出 KeyError,提示名字已存在。
    • 如果 forceTrue,或者名字不存在于 _module_dict 中,则将 module_class 注册到 _module_dict 中,键为 module_name 中的名字,值为 module_class

deprecated_register_module 方法

def deprecated_register_module(self, cls=None, force=False):
    warnings.warn(
        'The old API of register_module(module, force=False) '
        'is deprecated and will be removed, please use the new API '
        'register_module(name=None, force=False, module=None) instead.')
    if cls is None:
        return partial(self.deprecated_register_module, force=force)
    self._register_module(cls, force=force)
    return cls
作用:

deprecated_register_module 方法提供了对旧的类注册 API 的兼容性支持。它允许用户使用较旧的方式来注册类,并在使用时发出警告,建议迁移到新的 API。

参数说明:
  • cls:要注册的类对象。如果为 None,方法返回一个部分应用的函数,该函数在调用时会传入类对象。

  • force:布尔值,指示是否允许覆盖已存在的同名类。

逻辑详解:
  1. 警告: 方法首先发出一个弃用警告,提示用户旧的 API 将被移除,并建议使用新的 register_module 方法。

  2. 返回部分应用函数: 如果 clsNone,方法使用 functools.partial 返回一个部分应用的函数,该函数封装了 deprecated_register_module,但 force 参数已经绑定。这样用户可以稍后传入类对象。

  3. 注册类: 如果 cls 不为 None,方法直接调用 _register_module 方法将 cls 注册到注册表中。

  4. 返回类: 注册完成后,方法返回 cls 类对象本身。

register_module 方法

def register_module(self, name=None, force=False, module=None):
    if not isinstance(force, bool):
        raise TypeError(f'force must be a boolean, but got {type(force)}')

    if isinstance(name, type):
        return self.deprecated_register_module(name, force=force)

    if not (name is None or isinstance(name, str) or is_seq_of(name, str)):
        raise TypeError(
            'name must be either of None, an instance of str or a sequence'
            f'  of str, but got {type(name)}')

    if module is not None:
        self._register_module(
            module_class=module, module_name=name, force=force)
        return module

    def _register(cls):
        self._register_module(
            module_class=cls, module_name=name, force=force)
        return cls

    return _register
作用:

register_module 方法是 Registry 类的一个公共方法,用于注册类到注册表中。它可以作为装饰器或普通方法使用,为用户提供了灵活的类注册接口。

参数说明:
  • name:注册的类的名称或名称列表。可以是 None(使用类的默认名称)、字符串,或字符串的列表。

  • force:布尔值,指示是否允许覆盖已存在的同名类。如果为 False,且 _module_dict 中已存在同名类,方法将抛出 KeyError

  • module:需要注册的类对象。如果提供了 module,方法会立即将其注册到注册表中。

逻辑详解:
  1. 参数验证

    • 首先,检查 force 是否为布尔值。如果不是,抛出 TypeError
    • 接下来,检查 name 是否为一个类对象。如果是,调用旧的注册方法 deprecated_register_module 进行注册,以保持兼容性。
    • 然后,检查 name 是否为 None、字符串或字符串列表。如果不是,抛出 TypeError
  2. 直接注册类

    • 如果 module 不为 None,方法将 modulename 传递给 _register_module 进行注册。
    • 注册完成后,方法返回 module 类对象本身。
  3. 作为装饰器使用

    • 如果 moduleNone,方法返回一个内部的 _register 函数,这个函数会作为装饰器使用。
    • _register 函数接受一个类作为参数,将其与 name 一起传递给 _register_module 进行注册。
    • 装饰器的效果是:被装饰的类会被自动注册到注册表中,注册完成后,返回该类对象本身。
使用示例:
  1. 作为装饰器使用

    backbones = Registry('backbone')
    
    @backbones.register_module()
    class ResNet:
        pass
    

    在这个示例中,ResNet 类被注册到 backbones 注册表中。register_module 方法作为装饰器使用,不需要显式地调用 _register_module

  2. 指定名称的装饰器

    backbones = Registry('backbone')
    
    @backbones.register_module(name='mnet')
    class MobileNet:
        pass
    

    在这个示例中,MobileNet 类被注册为 mnet,而不是其默认的类名。这样在注册表中可以通过 'mnet' 来引用这个类。

  3. 作为普通方法调用

    backbones = Registry('backbone')
    class ResNet:
        pass
    
    backbones.register_module(module=ResNet)
    

    这种方式是直接调用 register_module 方法,将 ResNet 类注册到 backbones 注册表中。

总结

这三个方法共同实现了 Registry 类的核心功能,允许用户将类注册到注册表中,并通过简单的字符串来引用和使用这些类。_register_module 是底层的注册逻辑,deprecated_register_module 提供了对旧的 API 的兼容性支持,而 register_module 是一个高层次的接口,提供了灵活的类注册方式。

(调试BEVformer的一些感想)

  • 8
    点赞
  • 16
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值