# demo/registry_import_all.py
from demo.registry_ssd import Toy_SSD
from demo.registry_yolo import Toy_Yolo
类的装饰器在文件导入的时候会起作用,会对其进行装饰函数的内容(注册:在registry_model内,设置了从str到类的映射)
将所有的导入工作放入到一个文件中,使用时只需要导入这一个文件,即导入了所有的模型文件,完成所有模型的注册
使用的是类的对象的方法去装饰类本身,被修饰类将作为registry_model的参数传入到装饰函数中去
# demo/registry_ssd.py
from registry_demo_root import MODEL
import torch.nn as nn
@MODEL.registry_model
class Toy_SSD(nn.Module):
def __init__(self, input, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
print(f"model of input : {input}")
self.input = input
def forward(self, x):
return x * self.input
# demo/registry_yolo.py
from registry_demo_root import MODEL
import torch.nn as nn
@MODEL.registry_model
class Toy_Yolo(nn.Module):
def __init__(self, input, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
print(f"model of input : {input}")
self.input = input
def forward(self, x):
return x * self.input
根注册器:本质为注册器的一个对象
在注册模型时,导入的MODEL是统一的一个对象,其id唯一,使得注册模型时,使用的是同一个对象的_modele_dict属性,维护了同一个从str到类的映射
# demo/registry_demo_root.py
from typing import Type
class Registry:
def __init__(self) -> None:
self._modele_dict = dict()
def registry_model(self, model: Type):
self._modele_dict[model.__name__] = model
return model
def get(self, model_str: str):
return self._modele_dict[model_str]
MODEL = Registry()
在使用模型时,使用的也是同一个MODEL对象,,调用该对象的get方法,从该对象的model_dict中获取该类作为返回值
# demo/use_model.py
import registry_import_all # 将所有的模型都注册到root.py文件中的MODEL对象(id唯一)的_modele_dict字典中去
from registry_demo_root import MODEL # 从root.py中导入该对象(该对象的id唯一)
registry = MODEL # 两者的id相同
model_str = 'Toy_SSD'
obj_cls = registry.get(model_str) # 获取从字符串到类的映射
model = obj_cls(8888) # 实例化一个对象
result = model(222) # 模型的前向传播
print(f"the model is ok, the foward fun's output is {result}")