[拆轮子] PaddleDetection中__shared__、__inject__ 和 from_config 三者分别做了什么

在上一篇中,PaddleDetection Register装饰器到底做了什么
https://blog.csdn.net/HaoZiHuang/article/details/128668393

已经介绍了 __shared____inject__ 的作用:

  • __inject__ 表示引入全局字典中已经封装好的模块。如loss等。
  • __shared__为了实现一些参数的配置全局共享,这些参数可以被backbone, neck,head,loss等所有注册模块共享。

PaddleDetection 文档是这么说的,可是我还是不太懂。于是看了下源码,建议先看上边那篇文章,里边写了在哪部分 __inject__ 列表 和 __shared__列表被读取的。

标题中的三者都是在 ppdet/core/workspace.pycreate 函数使用的,create 函数用于创建已经被 Register装饰的注册过的类

1. __shared__ 部分

在 create 函数中先进行有效性检验, cls_or_name 可以是类别名称的字符串,也可以是已经写好的类,但在 PaddleDetection 当前版本内容,大概率只是字符串

    assert type(cls_or_name) in [type, str
                                 ], "should be a class or name of a class"
    name = type(cls_or_name) == str and cls_or_name or cls_or_name.__name__
    if name in global_config:
        if isinstance(global_config[name], SchemaDict):
        	# 如果 cls_or_name 这个类已经注册,则 global_config.values 元素是 SchemaDict
            pass
            
        elif hasattr(global_config[name], "__dict__"):
            # support instance return directly
            # 如果有 __dict__ 则直接返回hhhh( 当前版本用的不多 )
            return global_config[name]
            
        else:
            raise ValueError("The module {} is not registered".format(name))
    else:
        raise ValueError("The module {} is not registered".format(name))

之后解析 __shared__ 列表中的内容

    # parse `shared` annoation of registered modules
    if getattr(config, 'shared', None):
        for k in config.shared:
            target_key = config[k]
            shared_conf = config.schema[k].default
            assert isinstance(shared_conf, SharedConfig)
            if target_key is not None and \
                   not isinstance(target_key, SharedConfig):
                continue  # 如果当前当前 target_key 不是SharedConfig, 那么参数已被传入
			
			# 
            elif shared_conf.key in global_config:
                # `key` is present in config
                cls_kwargs[k] = global_config[shared_conf.key]  # 必须在全局设置! __shared__ (num_classes之类的)
            else:
                cls_kwargs[k] = shared_conf.default_value       # 否则就搞默认的

而之后的几行如果在全局配置过,比如这样:
在这里插入图片描述
则读取全局配置的内容

2. from_config 部分

之后执行:

    if getattr(cls, 'from_config', None):
        cls_kwargs.update(cls.from_config(config, **kwargs))

由于 backbone neck head 之间的配置可能存在耦合,于是部分类实例化时,可能需要之前模块的配置,所以要在 architecture 初始化时,创建 neck head 之类的

给个例子看吧,transformer 和 detr_head 创建时除了读取之前 config 的内容,也传入了来自前置模块的内容

    @classmethod
    def from_config(cls, cfg, *args, **kwargs):
        # backbone
        backbone = create(cfg['backbone'])
        # transformer
        kwargs = {'input_shape': backbone.out_shape}
        transformer = create(cfg['transformer'], **kwargs)
        # head
        kwargs = {
            'hidden_dim': transformer.hidden_dim,
            'nhead': transformer.nhead,
            'input_shape': backbone.out_shape
        }
        detr_head = create(cfg['detr_head'], **kwargs)

        return {
            'backbone': backbone,
            'transformer': transformer,
            "detr_head": detr_head,
        }

3. __inject__ 部分

__inject__ 部分其实与 from_config 很像,都是将类实例化为对象,来看一小部分

在这里插入图片描述

k'loss',之前在 __inject__ 列表中
target_key'DETRLoss' 是一个字符串

	target_key = config[k]
	......
	
    elif isinstance(target_key, str):
        if target_key not in global_config:
            raise ValueError("Missing injection config:", target_key)
        target = global_config[target_key]
        if isinstance(target, SchemaDict):
            cls_kwargs[k] = create(target_key)   # 在此处将类实例化
        elif hasattr(target, '__dict__'):  # serialized object
            cls_kwargs[k] = target

可以看到 from_config 是由于组件之间存在参数耦合,要在前者创建完毕后,将部分参数传给后者,所以要借助 create API 手动实例化

__inject__ 的使用很简单,只许在 __inject__ 中指定对应的参数即可,如上图中指定了 loss 部分,而 loss 参数是 DETRLoss,于是 loss 传入后是一个 实例化的 DETRLoss 对象

4. 附录 create 函数源码

def create(cls_or_name, **kwargs):
    """
    Create an instance of given module class.

    Args:
        cls_or_name (type or str): Class of which to create instance.

    Returns: instance of type `cls_or_name`
    """
    assert type(cls_or_name) in [type, str
                                 ], "should be a class or name of a class"
    name = type(cls_or_name) == str and cls_or_name or cls_or_name.__name__
    if name in global_config:
        if isinstance(global_config[name], SchemaDict):
            pass
        elif hasattr(global_config[name], "__dict__"):
            # support instance return directly
            return global_config[name]
        else:
            raise ValueError("The module {} is not registered".format(name))
    else:
        raise ValueError("The module {} is not registered".format(name))

    config = global_config[name]
    cls = getattr(config.pymodule, name)
    cls_kwargs = {}
    cls_kwargs.update(global_config[name])

    # parse `shared` annoation of registered modules
    if getattr(config, 'shared', None):
        for k in config.shared:
            target_key = config[k]
            shared_conf = config.schema[k].default
            assert isinstance(shared_conf, SharedConfig)
            if target_key is not None and not isinstance(target_key,
                                                         SharedConfig): # 如果我指定则就传入指定的
                continue  # value is given for the module
            elif shared_conf.key in global_config:
                # `key` is present in config
                cls_kwargs[k] = global_config[shared_conf.key]  # 必须在全局设置! __shared__ (num_classes之类的)
            else:
                cls_kwargs[k] = shared_conf.default_value       # 否则就搞默认的

    # parse `inject` annoation of registered modules
    if getattr(cls, 'from_config', None):
        cls_kwargs.update(cls.from_config(config, **kwargs))

    if getattr(config, 'inject', None):
        for k in config.inject:
            target_key = config[k]
            # optional dependency
            if target_key is None:
                continue

            if isinstance(target_key, dict) or hasattr(target_key, '__dict__'):
                if 'name' not in target_key.keys():
                    continue
                inject_name = str(target_key['name'])
                if inject_name not in global_config:
                    raise ValueError(
                        "Missing injection name {} and check it's name in cfg file".
                        format(k))
                target = global_config[inject_name]
                for i, v in target_key.items():
                    if i == 'name':
                        continue
                    target[i] = v
                if isinstance(target, SchemaDict):
                    cls_kwargs[k] = create(inject_name)
            elif isinstance(target_key, str):
                if target_key not in global_config:
                    raise ValueError("Missing injection config:", target_key)
                target = global_config[target_key]
                if isinstance(target, SchemaDict):
                    cls_kwargs[k] = create(target_key)
                elif hasattr(target, '__dict__'):  # serialized object
                    cls_kwargs[k] = target
            else:
                raise ValueError("Unsupported injection type:", target_key)
    # prevent modification of global config values of reference types
    # (e.g., list, dict) from within the created module instances
    #kwargs = copy.deepcopy(kwargs)
    return cls(**cls_kwargs)

评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值