一、注册器(registry)介绍
MMEngine 实现的注册器可以看作一个映射表
和模块构建方法(build function)
的组合。
映射表维护了一个字符串(str)
到类(class)
或者函数(function)
的映射,使得用户可以借助字符串查找到相应的类或函数,例如维护字符串 “ResNet” 到 ResNet 类或函数的映射,使得用户可以通过 “ResNet” 找到 ResNet 类;
而模块构建方法则定义了如何根据字符串(str)
查找到对应的类(class)
或者函数(function)
以及如何实例化
这个类或者调用这个函数,例如,通过字符串 “bn” 找到 nn.BatchNorm2d 并实例化 BatchNorm2d 模块;又或者通过字符串 “build_batchnorm2d” 找到 build_batchnorm2d 函数并返回该函数的调用结果。MMEngine 中的注册器默认使用 build_from_cfg 函数来查找并实例化字符串对应的类或者函数。
一个注册器管理的类或函数通常有相似的接口和功能,因此该注册器可以被视作这些类或函数的抽象
。例如注册器 MODELS 可以被视作所有模型的抽象,管理了 ResNet,SEResNet 和 RegNetX 等分类网络的类以及 build_ResNet, build_SEResNet 和 build_RegNetX 等分类网络的构建函数。
二、 入门用法 - 注册 class
使用注册器管理代码库中的模块,需要以下三个步骤。
创建注册器
创建一个用于实例化类的构建方法(可选,在大多数情况下可以只使用默认方法)
将模块加入注册器中
假设我们要实现一系列激活模块并且希望仅修改配置就能够使用不同的激活模块而无需修改代码。
2.1 首先创建注册器
from mmengine import Registry
# scope 表示注册器的作用域,如果不设置,默认为包名,例如在 mmdetection 中,它的 scope 为 mmdet
# locations 表示注册在此注册器的模块所存放的位置,注册器会根据预先定义的位置在构建模块时自动 import
MODELS = Registry('activation', scope='mmengine', locations=['mmengine.models.activations'])
ACTIVATION = MODELS
ACTIVATION_SHUAI = MODELS
locations
指定的模块 mmengine.models.activations
对应了 mmengine/models/activations.py
文件。
在使用注册器构建模块的时候,ACTIVATION
注册器会自动从该文件中导入实现的模块。
因此,我们可以在 mmengine/models/activations.py
文件中实现不同的激活函数
,例如 Sigmoid,ReLU 和 Softmax
。
import torch.nn as nn
# Sigmoid
# 使用注册器管理模块
@ACTIVATION.register_module()
class Sigmoid_Shuai(nn.Module):
def __init__(self):
super().__init__()
def forward(self, x):
print('call Sigmoid.forward')
return x
# ReLU
@ACTIVATION_SHUAI.register_module()
class ReLU_Shuai(nn.Module):
def __init__(self, inplace=False):
super().__init__()
def forward(self, x):
print('call ReLU.forward')
return x
# Softmax
@ACTIVATION.register_module()
class Softmax_Shuai(nn.Module):
def __init__(self):
super().__init__()
def forward(self, x):
print('call Softmax.forward')
return x
使用注册器管理模块的关键步骤
是,将实现的模块注册到注册表 ACTIVATION 中
。通过 @ACTIVATION.register_module() 装饰
所实现的模块,字符串和类或函数之间的映射就可以由 ACTIVATION 构建和维护
,我们也可以通过 ACTIVATION.register_module(module=ReLU) 实现同样的功能。
通过
注册(register)
,我们就可以通过 ACTIVATION 建立字符串(str)
到类(class)
或者函数(function)
之间的映射(mapping)
print(ACTIVATION.module_dict)
print(ACTIVATION_SHUAI.module_dict)
只有模块所在的文件被导入时,注册机制才会被触发,用户可以通过三种方式将模块添加到注册器中:
1.在 locations 指向的文件中实现模块。注册器将自动在预先定义的位置导入模块。
这种方式是为了简化算法库的使用,以便用户可以直接使用 REGISTRY.build(cfg)。
2. 手动导入文件。常用于用户在算法库之内或之外实现新的模块。
3. 在配置中使用 custom_imports 字段。 详情请参考导入自定义Python模块。
2.2 使用模块
2.2.1 默认构建流程
模块成功注册后,我们可以通过配置文件使用这个激活模块。
# 模块成功注册后,我们可以通过配置文件使用这个激活模块。
# 配置文件(字符串)
sigmoid_act_cfg = dict(type='Sigmoid_Shuai')
# 调用 Sigmoid.__init__
sigmoid_activation = ACTIVATION.build(sigmoid_act_cfg)
# 调用 Sigmoid.forward
import torch
input = torch.randn(2)
print(sigmoid_activation(input))
2.2.2 自定义构建流程
如果我们希望在创建实例前检查输入参数的类型(或者任何其他操作),我们可以实现一个构建方法并将其传递给注册器从而实现自定义构建流程。
# 自定义,相当于build_loss, build_model
def build_activation(cfg, registry, *args, **kwargs):
cfg_ = cfg.copy()
act_type = cfg_.pop('type')
print('你可以在这里做一些自己的操作!')
print('2024//9/30 by SHUAI')
print('自动驾驶算法岗位找人,有意愿的可以私聊!')
# 打印...类型
print(f'build your activation: {act_type}')
act_cls = registry.get(act_type)
# 调用 Sigmoid.__init__
sigmoid_activation = act_cls(*args, **kwargs, **cfg_)
return sigmoid_activation
# 配置文件(字符串)
sigmoid_act_cfg = dict(type='Sigmoid_Shuai')
# 默认构建流程
# sigmoid_activation = ACTIVATION.build(sigmoid_act_cfg)
# 自定义构建流程:增加自己的一些操作
sigmoid_activation = build_activation(sigmoid_act_cfg, ACTIVATION)
import torch
input = torch.randn(2)
print(sigmoid_activation(input))
三、 入门用法 - 注册 function
FUNCTION = Registry('SHUAI')
@FUNCTION.register_module()
def print_SHUAI():
print("2024/9/30 by SHUAI")
func_cfg = dict(type='print')
func_res = FUNCTION.build(func_cfg)
注意这里配置文件中的字符串(str)
func_cfg = dict(type=‘print’)的 type=‘print’
和@FUNCTION.register_module()装饰的 print_SHUAI不一致,会报错。
修改如下:
cfg = dict(type='print_SHUAI')
func_res = FUNCTION.build(cfg)
输出: