paddle detection 配置文件怎么实例化的 代码梳理 -----(regiester)

继续上一篇我们讲了通过yaml配置文件实例化类,其中配置文件要加上“!”这个符号,如果没有这个符号的其他类是怎么注册实例化的呢?

我们看到代码workspace.py中

def register(cls):
    """
    Register a given module class.

    Args:
        cls (type): Module class to be registered.

    Returns: cls
    """
    if cls.__name__ in global_config:
        raise ValueError("Module class already registered: {}".format(
            cls.__name__))
    if hasattr(cls, '__op__'):
        cls = make_partial(cls)
    global_config[cls.__name__] = extract_schema(cls)
    return cls

这个函数就是每个类的装饰器,他干啥了呢,一句话就是他把类名和类的地址放在全局变量的字典里了
下面具体分析

1.类上加上装饰器

在这里插入图片描述
我们发现每个类上都有这个装饰器register,
装饰器获取cls名, cls.name 那么extract_schema是干嘛

2.根据类名和参数保存在另一个类schemadict中

def extract_schema(cls):
    """
    Extract schema from a given class

    Args:
        cls (type): Class from which to extract.

    Returns:
        schema (SchemaDict): Extracted schema.
    """
    ctor = cls.__init__
    # python 2 compatibility
    if hasattr(inspect, 'getfullargspec'):
        argspec = inspect.getfullargspec(ctor)
        annotations = argspec.annotations
        has_kwargs = argspec.varkw is not None
    else:
        argspec = inspect.getfullargspec(ctor)
        # python 2 type hinting workaround, see pep-3107
        # however, since `typeguard` does not support python 2, type checking
        # is still python 3 only for now
        annotations = getattr(ctor, '__annotations__', {})
        has_kwargs = argspec.varkw is not None

    names = [arg for arg in argspec.args if arg != 'self']
    defaults = argspec.defaults
    num_defaults = argspec.defaults is not None and len(argspec.defaults) or 0
    num_required = len(names) - num_defaults

    docs = cls.__doc__
    if docs is None and getattr(cls, '__category__', None) == 'op':
        docs = cls.__call__.__doc__
    try:
        docstring = doc_parse(docs)
    except Exception:
        docstring = None

    if docstring is None:
        comments = {}
    else:
        comments = {}
        for p in docstring.params:
            match_obj = re.match('^([a-zA-Z_]+[a-zA-Z_0-9]*).*', p.arg_name)
            if match_obj is not None:
                comments[match_obj.group(1)] = p.description

    schema = SchemaDict()
    schema.name = cls.__name__
    schema.doc = ""
    if docs is not None:
        start_pos = docs[0] == '\n' and 1 or 0
        schema.doc = docs[start_pos:].split("\n")[0].strip()
    # XXX handle paddle's weird doc convention
    if '**' == schema.doc[:2] and '**' == schema.doc[-2:]:
        schema.doc = schema.doc[2:-2].strip()
    schema.category = hasattr(cls, '__category__') and getattr(
        cls, '__category__') or 'module'
    schema.strict = not has_kwargs
    schema.pymodule = importlib.import_module(cls.__module__)
    schema.inject = getattr(cls, '__inject__', [])
    schema.shared = getattr(cls, '__shared__', [])
    for idx, name in enumerate(names):
        comment = name in comments and comments[name] or name
        if name in schema.inject:
            type_ = None
        else:
            type_ = name in annotations and annotations[name] or None
        value_schema = SchemaValue(name, comment, type_)
        if name in schema.shared:
            assert idx >= num_required, "shared config must have default value"
            default = defaults[idx - num_required]
            value_schema.set_default(SharedConfig(name, default))
        elif idx >= num_required:
            default = defaults[idx - num_required]
            value_schema.set_default(default)
        schema.set_schema(name, value_schema)

    return schema

这段代码就是获取cls的参数和地址,也就是cls.model 保存在schema 这个类中

这个类的定义在下面

class SchemaDict(dict):
    def __init__(self, **kwargs):
        super(SchemaDict, self).__init__()
        self.schema = {}
        self.strict = False
        self.doc = ""
        self.update(kwargs)

    def __setitem__(self, key, value):
        # XXX also update regular dict to SchemaDict??
        if isinstance(value, dict) and key in self and isinstance(self[key],
                                                                  SchemaDict):
            self[key].update(value)
        else:
            super(SchemaDict, self).__setitem__(key, value)

    def __missing__(self, key):
        if self.has_default(key):
            return self.schema[key].default
        elif key in self.schema:
            return self.schema[key]
        else:
            raise KeyError(key)

    def copy(self):
        newone = SchemaDict()
        newone.__dict__.update(self.__dict__)
        newone.update(self)
        return newone

    def set_schema(self, key, value):
        assert isinstance(value, SchemaValue)
        self.schema[key] = value

    def set_strict(self, strict):
        self.strict = strict

    def has_default(self, key):
        return key in self.schema and self.schema[key].has_default()

    def is_default(self, key):
        if not self.has_default(key):
            return False
        if hasattr(self[key], '__dict__'):
            return True
        else:
            return key not in self or self[key] == self.schema[key].default

    def find_default_keys(self):
        return [
            k for k in list(self.keys()) + list(self.schema.keys())
            if self.is_default(k)
        ]

    def mandatory(self):
        return any([k for k in self.schema.keys() if not self.has_default(k)])

    def find_missing_keys(self):
        missing = [
            k for k in self.schema.keys()
            if k not in self and not self.has_default(k)
        ]
        placeholders = [k for k in self if self[k] in ('<missing>', '<value>')]
        return missing + placeholders

    def find_extra_keys(self):
        return list(set(self.keys()) - set(self.schema.keys()))

    def find_mismatch_keys(self):
        mismatch_keys = []
        for arg in self.schema.values():
            if arg.type is not None:
                try:
                    check_type("{}.{}".format(self.name, arg.name),
                               self[arg.name], arg.type)
                except Exception:
                    mismatch_keys.append(arg.name)
        return mismatch_keys

    def validate(self):
        missing_keys = self.find_missing_keys()
        if missing_keys:
            raise ValueError("Missing param for class<{}>: {}".format(
                self.name, ", ".join(missing_keys)))
        extra_keys = self.find_extra_keys()
        if extra_keys and self.strict:
            raise ValueError("Extraneous param for class<{}>: {}".format(
                self.name, ", ".join(extra_keys)))
        mismatch_keys = self.find_mismatch_keys()
        if mismatch_keys:
            raise TypeError("Wrong param type for class<{}>: {}".format(
                self.name, ", ".join(mismatch_keys)))

他继承了dict类

在这里插入图片描述
最重要的两个操作,importlib.import_module 是动态获取cls的地址。
这样全局变量global_config 里面就存储了很多键值对,
在这里插入图片描述
每个valu就是schemadict类型,里存放文件的地址
在这里插入图片描述

3.实例话

那么在什么时候实例化的
1,首先train.py 中的
cfg = load_config(FLAGS.config) 这句会把yaml配置文件中带!的类实例化放进global_config中。其他的200多的类在文件运行的时候就放进global_config中,也就是上面的部分。那么剩下的类只是放进global_config中,并没有实例化。

在这里插入图片描述
在 trainer.py中
在这里插入图片描述

也就是在create中 执行了,
在这里插入图片描述
在这里插入图片描述

那么参数里面包含另一个类怎么实例化呢,比如yolov3 包含了 resnet 、ppyolofpn、head等
在这里插入图片描述
yolov3类中有一个变量叫做 inject = [‘post_process’] ,这个保存在shemdict的inject中
在这里插入图片描述
在这里插入图片描述

还是在creat函数中,上面我隐藏的getattr中
在这里插入图片描述
他会循环调用creat函数 ,把需要的成员变量也实例化注册进实例类中。

  • 1
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值