PaddleDetection代码解析之配置文件分析

2021SC@SDUSC
以下是PaddleDetection的配置文件的代码,我经过阅读后,在代码的关键部分和难理解的部分加上了备注:

'''
tree -f#显示显示完整路径
tree -d#显示所有文件夹
└── PaddleDetection #主目录
    ├── configs  #存放配置文件
    │    ├──_base_ 各个模块配置
    ├── dataset  #存放数据集,数据集下载脚本,对应各数据集文件夹
    │   ├── coco #80类 #物体
    │   ├── fddb #1类  #人脸 通常用来评估人脸检测算法
    │   ├── mot # 多目标跟踪
    │   ├── roadsign_voc #车道线数据 4类
    │   ├── voc #20类 #物体
    │   └── wider_face #1类 人脸
    ├── deploy  #部署相关
    │   ├── cpp #C++部署
    │   │   ├── cmake #cmake文件
    │   │   ├── docs #部署文档
    │   │   ├── include #库头文件
    │   │   ├── scripts #依赖库配置脚本,build脚本
    │   │   └── src #源码
    │   └── python #python部署
    ├── ppdet   #飞桨物体检测套件
    │   ├── core #核心部分
    │   │   └── config #实例、注册类的配置
    │   ├── data #数据处理
    │   │   ├── shared_queue #共享队列(数据多线程)
    │   │   ├── source #各种数据集类
    │   │   ├── tests #测试
    │   │   ├── tools #工具(转coco数据格式)
    │   │   └── transform #数据增强模块
    │   ├── ext_op  #增加op
    │   │   ├── src #op实现源码
    │   │   └── test #op测试
    │   ├── modeling #模型结构
    │   │   ├── architectures #网络结构
    │   │   ├── backbones #主干网络
    │   │   ├── heads #头(RPN、loss)
    │   │   ├── losses #头(loss)
    │   │   ├── mot #多目标跟踪(tracking)
    │   │   ├── necks #颈(FPN)
    │   │   ├── proposal_generator #建议框生成(anchor、rpnhead)
    │   │   ├── reid #重识别
    │   │   └── tests #测试
    │   ├── py_op #一些前后处理
    │   └── utils #实用工具
    └── tools #训练,测试,验证,将来还有模型导出
'''
基本设计
#ppdet/core/workspace.py
#如何通过cfg来注册、创建类
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division

import importlib
import os
import sys

import yaml
import copy
import collections

from .config.schema import SchemaDict, SharedConfig, extract_schema
from .config.yaml_helpers import serializable


__all__ = [
    'global_config',
    'load_config',
    'merge_config',
    'get_registered_modules',
    'create',
    'register',
    'serializable',
    'dump_value',
]

#__all__变量的值是一个列表,存储的是当前模块中一些成员(变量、函数或者类)的名称。通过在模块文件中设置 __all__ 变量,
#当其它文件以“from 模块名 import *”的形式导入该模块时,该文件中只能使用 __all__ 列表中指定的成员。
#以“from 模块名 import *”形式导入的模块,当该模块设有 __all__ 变量时,只能导入该变量指定的成员,未指定的成员是无法导入的。

#python内置函数 getattr()、hasattr()介绍

'''
getattr(object, name[, default])
getattr() 是python 中的一个内置函数,用于返回一个对象属性值。

object -- 对象名。
name -- 字符串,对象属性。
default -- 默认返回值,如果不提供该参数,在没有对应属性时,将触发 AttributeError。

返回对象属性值。
'''

'''
hasattr(object,'name') 函数用于判断对象是否包含对应的属性。

object -- 对象。
name -- 字符串,属性名。

如果对象有该属性返回 True,否则返回 False。
'''

#除去已经注册类的值
def dump_value(value):
    # XXX this is hackish, but collections.abc is not available in python 2
    if hasattr(value, '__dict__') or isinstance(value, (dict, tuple, list)):
        value = yaml.dump(value, default_flow_style=True)
        value = value.replace('\n', '')
        value = value.replace('...', '')
        return "'{}'".format(value)
    else:
        # primitive types
        return str(value)

#遍历字典属性
class AttrDict(dict):
    """Single level attribute dict, NOT recursive"""

    def __init__(self, **kwargs):
        super(AttrDict, self).__init__()
        super(AttrDict, self).update(kwargs)

    def __getattr__(self, key):
        if key in self:
            return self[key]
        raise AttributeError("object has no attribute '{}'".format(key))


global_config = AttrDict()

READER_KEY = '_READER_'

#加载配置文件函数
def load_config(file_path):
    """
    Load config from file.
    Args:
        file_path (str): Path of the config file to be loaded.
    Returns: global config
    """
    _, ext = os.path.splitext(file_path)
    assert ext in ['.yml', '.yaml'], "only support yaml files for now"

    cfg = AttrDict()
    with open(file_path) as f:
        cfg = merge_config(yaml.load(f, Loader=yaml.Loader), cfg)

    if READER_KEY in cfg:
        reader_cfg = cfg[READER_KEY]
        if reader_cfg.startswith("~"):
            reader_cfg = os.path.expanduser(reader_cfg)
        if not reader_cfg.startswith('/'):
            reader_cfg = os.path.join(os.path.dirname(file_path), reader_cfg)

        with open(reader_cfg) as f:
            merge_config(yaml.load(f, Loader=yaml.Loader))
        del cfg[READER_KEY]

    merge_config(cfg)
    return global_config

#合并配置函数
def dict_merge(dct, merge_dct):
    """ Recursive dict merge. Inspired by :meth:``dict.update()``, instead of
    updating only top-level keys, dict_merge recurses down into dicts nested
    to an arbitrary depth, updating keys. The ``merge_dct`` is merged into
    ``dct``.
    Args:
        dct: dict onto which the merge is executed
        merge_dct: dct merged into dct
    Returns: dct
    """
    for k, v in merge_dct.items():
        if (k in dct and isinstance(dct[k], dict) and
                isinstance(merge_dct[k], collections.Mapping)):
            dict_merge(dct[k], merge_dct[k])
        else:
            dct[k] = merge_dct[k]
    return dct

#合并全局配置函数
def merge_config(config, another_cfg=None):
    """
    Merge config into global config or another_cfg.
    Args:
        config (dict): Config to be merged.
    Returns: global config
    """
    global global_config
    dct = another_cfg if another_cfg is not None else global_config
    return dict_merge(dct, config)

#获取注册模块
def get_registered_modules():
    return {k: v for k, v in global_config.items() if isinstance(v, SchemaDict)}

#加载注册表中的模块文件
def make_partial(cls):
    op_module = importlib.import_module(cls.__op__.__module__)
    op = getattr(op_module, cls.__op__.__name__)
    cls.__category__ = getattr(cls, '__category__', None) or 'op'

    def partial_apply(self, *args, **kwargs):
        kwargs_ = self.__dict__.copy()
        kwargs_.update(kwargs)
        return op(*args, **kwargs_)

    if getattr(cls, '__append_doc__', True):  # XXX should default to True?
        if sys.version_info[0] > 2:
            cls.__doc__ = "Wrapper for `{}` OP".format(op.__name__)
            cls.__init__.__doc__ = op.__doc__
            cls.__call__ = partial_apply
            cls.__call__.__doc__ = op.__doc__
        else:
            # XXX work around for python 2
            partial_apply.__doc__ = op.__doc__
            cls.__call__ = partial_apply
    return cls

#注册模块类
def register(cls):
    """
    Register a given module class.
    Args:
        cls (type): Module class to be registered.
    Returns: cls
    """
    if cls.__name__ in global_config:
        raise ValueError("Module class already registered: {}".format(
            cls.__name__))
    if hasattr(cls, '__op__'):
        cls = make_partial(cls)
    global_config[cls.__name__] = extract_schema(cls)
    return cls

#创建模块类
def create(cls_or_name, **kwargs):
    """
    Create an instance of given module class.
    Args:
        cls_or_name (type or str): Class of which to create instance.
    Returns: instance of type `cls_or_name`
    """
    assert type(cls_or_name) in [type, str
                                 ], "should be a class or name of a class"
    name = type(cls_or_name) == str and cls_or_name or cls_or_name.__name__
    assert name in global_config and \
        isinstance(global_config[name], SchemaDict), \
        "the module {} is not registered".format(name)
    config = global_config[name]
    config.update(kwargs)
    config.validate()
    cls = getattr(config.pymodule, name)
    kwargs = {}
    kwargs.update(global_config[name])

    # parse `shared` annoation of registered modules
    if getattr(config, 'shared', None):
        for k in config.shared:

            target_key = config[k]
            shared_conf = config.schema[k].default
            assert isinstance(shared_conf, SharedConfig)
            if target_key is not None and not isinstance(target_key,
                                                         SharedConfig):
                continue  # value is given for the module
            elif shared_conf.key in global_config:
                # `key` is present in config
                kwargs[k] = global_config[shared_conf.key]
            else:
                kwargs[k] = shared_conf.default_value

    # parse `inject` annoation of registered modules
    if getattr(config, 'inject', None):
        for k in config.inject:
            target_key = config[k]
            # optional dependency
            if target_key is None:
                continue

            if isinstance(target_key, dict) or hasattr(target_key, '__dict__'):
                if 'name' not in target_key.keys():
                    continue
                inject_name = str(target_key['name'])
                if inject_name not in global_config:
                    raise ValueError(
                        "Missing injection name {} and check it's name in cfg file".
                        format(k))
                target = global_config[inject_name]
                for i, v in target_key.items():
                    if i == 'name':
                        continue
                    target[i] = v
                if isinstance(target, SchemaDict):
                    kwargs[k] = create(inject_name)
            elif isinstance(target_key, str):
                if target_key not in global_config:
                    raise ValueError("Missing injection config:", target_key)
                target = global_config[target_key]
                if isinstance(target, SchemaDict):
                    kwargs[k] = create(target_key)
                elif hasattr(target, '__dict__'):  # serialized object
                    kwargs[k] = target
            else:
                raise ValueError("Unsupported injection type:", target_key)
    # prevent modification of global config values of reference types
    # (e.g., list, dict) from within the created module instances
    #kwargs = copy.deepcopy(kwargs)
    return cls(**kwargs)

经过我的阅读,并一行一行地进行分析,我有如下收获:

1.通过上面的yaml配置文件,根据需要来传参注册相应的模块(类),模块主要是Data 、Model、Optimizer
    a.Data中就包含了数据读取、数据预处理、数据封装。
    b.Model中是检测算法的核心部分,model的architecture主要分为Backbone、Neck、Head、Post_process四部分:

  • <1>Backbone 主干网络(如:resnet、darknet、mobilenet…)
  • <2>Neck 颈(增强特征提取,如SPP,FPN)
  • <3>Head 头 (算法核心,YOLO核心思想、Rcnn核心思想、Anchor Free核心思想)
  • <4>Post_process 后处理(通常用来将特征与GT对齐)

  c.Optimizer中就包含学习率策略、优化器的选择,用来更新网络权重

2.将所需要的模块(数据读取、模型构建、后处理、优化器)组建成目标检测器

3.有了目标检测器我们就可以根据需要用于训练、验证或者推理。
 

  • 4
    点赞
  • 12
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值