注册器类(Registry)
在mmdetection中,将会使用该类构建9个注册类实例,其实就是对类做一个划分管理.
比如,backbone 作为一族(vgg,resnet等)
文件:mmdetmodelsregistry.py
BACKBONES = Registry('backbone')
NECKS = Registry('neck')
ROI_EXTRACTORS = Registry('roi_extractor')
SHARED_HEADS = Registry('shared_head')
HEADS = Registry('head')
LOSSES = Registry('loss')
DETECTORS = Registry('detector')
文件:mmdetdatasetsregistry.py
DATASETS = Registry('dataset')
PIPELINES = Registry('pipeline')
每一个实例,都是存放属于这一簇的类,将来通过get key方式获取,key 来自于config文件.
mmdetection在构建模型的过程中,一直是通过key 去查找对应的类(在注册器中),找到对应的类,然后实例化,最终将配置描述的模型,构建出来.
举个栗子:
key = 'vgg'
VGG = BACKBONES.get(key)
key = 'bce'
BCE = LOSSES .get(key)
Registry 类
#!/usr/bin/python3
# -*- coding: utf-8 -*-
import inspect
class Registry(object):
def __init__(self, name):
self._name = name
self._module_dict = dict()
def __repr__(self):
format_str = self.__class__.__name__ + '(name={}, items={})'.format(
self._name, list(self._module_dict.keys()))
return format_str
@property
def name(self):
return self._name
@property
def module_dict(self):
return self._module_dict
def get(self, key):
return self._module_dict.get(key, None)
def _register_module(self, module_class):
"""Register a module.
Args:
module (:obj:`nn.Module`): Module to be registered.
"""
if not inspect.isclass(module_class):
raise TypeError('module must be a class, but got {}'.format(
type(module_class)))
module_name = module_class.__name__
if module_name in self._module_dict:
raise KeyError('{} is already registered in {}'.format(
module_name, self.name))
self._module_dict[module_name] = module_class
def register_module(self, cls):
self._register_module(cls)
return cls
举个栗子:
在mmdetection的代码中,将一个类注册(插入)到(某一个)注册器里面,是直接写在类的声明上方.
ANIMAL = Registry('animal')
@ANIMAL.register_module
class Dog(object):
def __init__(self):
pass
def run(self):
print('running dog')
# ANIMAL.register_module(Dog)
dog = ANIMAL.get('Dog')
d = dog()
d.run()
等价写法:
ANIMAL = Registry('animal')
class Dog(object):
def __init__(self):
pass
def run(self):
print('running dog')
ANIMAL.register_module(Dog)
dog = ANIMAL.get('Dog')
d = dog()
d.run()
两者输出结果皆为:(不考虑python语言的特性,因为我不会)