2021SC@SDUSC
以下是PaddleDetection的配置文件的代码,我经过阅读后,在代码的关键部分和难理解的部分加上了备注:
'''
tree -f #显示完整路径
tree -d #只显示文件夹
└── PaddleDetection #主目录
├── configs #存放配置文件
│ ├── _base_ #各个模块配置
├── dataset #存放数据集,数据集下载脚本,对应各数据集文件夹
│ ├── coco #80类 #物体
│ ├── fddb #1类 #人脸 通常用来评估人脸检测算法
│ ├── mot # 多目标跟踪
│ ├── roadsign_voc #路标(交通标志)数据 4类
│ ├── voc #20类 #物体
│ └── wider_face #1类 人脸
├── deploy #部署相关
│ ├── cpp #C++部署
│ │ ├── cmake #cmake文件
│ │ ├── docs #部署文档
│ │ ├── include #库头文件
│ │ ├── scripts #依赖库配置脚本,build脚本
│ │ └── src #源码
│ └── python #python部署
├── ppdet #飞桨物体检测套件
│ ├── core #核心部分
│ │ └── config #实例、注册类的配置
│ ├── data #数据处理
│ │ ├── shared_queue #共享队列(数据多线程)
│ │ ├── source #各种数据集类
│ │ ├── tests #测试
│ │ ├── tools #工具(转coco数据格式)
│ │ └── transform #数据增强模块
│ ├── ext_op #增加op
│ │ ├── src #op实现源码
│ │ └── test #op测试
│ ├── modeling #模型结构
│ │ ├── architectures #网络结构
│ │ ├── backbones #主干网络
│ │ ├── heads #头(RPN、loss)
│ │ ├── losses #损失函数(loss)
│ │ ├── mot #多目标跟踪(tracking)
│ │ ├── necks #颈(FPN)
│ │ ├── proposal_generator #建议框生成(anchor、rpnhead)
│ │ ├── reid #重识别
│ │ └── tests #测试
│ ├── py_op #一些前后处理
│ └── utils #实用工具
└── tools #训练,测试,验证,将来还有模型导出
'''
基本设计
#ppdet/core/workspace.py
#如何通过cfg来注册、创建类
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division
import importlib
import os
import sys
import yaml
import copy
import collections
from .config.schema import SchemaDict, SharedConfig, extract_schema
from .config.yaml_helpers import serializable
__all__ = [
'global_config',
'load_config',
'merge_config',
'get_registered_modules',
'create',
'register',
'serializable',
'dump_value',
]
#__all__变量的值是一个列表,存储的是当前模块中一些成员(变量、函数或者类)的名称。通过在模块文件中设置 __all__ 变量,
#当其它文件以“from 模块名 import *”的形式导入该模块时,该文件中只能使用 __all__ 列表中指定的成员。
#以“from 模块名 import *”形式导入的模块,当该模块设有 __all__ 变量时,只能导入该变量指定的成员,未指定的成员是无法导入的。
#python内置函数 getattr()、hasattr()介绍
'''
getattr(object, name[, default])
getattr() 是python 中的一个内置函数,用于返回一个对象属性值。
object -- 对象名。
name -- 字符串,对象属性。
default -- 默认返回值,如果不提供该参数,在没有对应属性时,将触发 AttributeError。
返回对象属性值。
'''
'''
hasattr(object,'name') 函数用于判断对象是否包含对应的属性。
object -- 对象。
name -- 字符串,属性名。
如果对象有该属性返回 True,否则返回 False。
'''
#将配置值转换(dump)为字符串
def dump_value(value):
    """Render a config value as a string.

    Objects that carry a ``__dict__`` and dict/tuple/list containers are
    dumped with ``yaml.dump`` in flow style, stripped of newlines and of the
    ``...`` marker, and wrapped in single quotes. Any other value is
    returned as ``str(value)``.

    Args:
        value: The value to render.

    Returns:
        str: The rendered value.
    """
    # XXX this is hackish, but collections.abc is not available in python 2,
    # so containers are recognised by their concrete types.
    is_primitive = not (hasattr(value, '__dict__') or
                        isinstance(value, (dict, tuple, list)))
    if is_primitive:
        return str(value)

    dumped = yaml.dump(value, default_flow_style=True)
    for token in ('\n', '...'):
        dumped = dumped.replace(token, '')
    return "'{}'".format(dumped)
#支持以属性方式读取键的字典(单层,不递归)
class AttrDict(dict):
    """Dict whose keys can also be read as attributes.

    Only a single level is supported: nested dicts stored as values are
    returned as they are and are NOT converted.
    """

    def __init__(self, **kwargs):
        base = super(AttrDict, self)
        base.__init__()
        base.update(kwargs)

    def __getattr__(self, key):
        # __getattr__ is only consulted after normal attribute lookup fails,
        # so real attributes and dict methods are not shadowed by keys.
        if key not in self:
            raise AttributeError(
                "object has no attribute '{}'".format(key))
        return self[key]
# Module-level config store: merge_config() merges into it by default and
# register() stores each registered class's schema in it under the class name.
global_config = AttrDict()
# Config key whose value is the path of a separate reader YAML file;
# load_config() loads that file and then drops the key.
READER_KEY = '_READER_'
#加载配置文件函数
def load_config(file_path):
    """
    Load a YAML config file and merge it into the global config.

    If the loaded config contains the ``READER_KEY`` entry, its value is
    treated as the path of a second YAML file. A leading ``~`` is expanded,
    and a relative path is resolved against the directory of ``file_path``.
    That file is merged into the global config first, so values from the
    main config file take precedence over it.

    Args:
        file_path (str): Path of the config file to be loaded. Must end in
            ``.yml`` or ``.yaml``.

    Returns: global config
    """
    _, ext = os.path.splitext(file_path)
    assert ext in ['.yml', '.yaml'], "only support yaml files for now"

    # NOTE(review): yaml.Loader can construct arbitrary Python objects, so
    # only trusted config files should be loaded here.
    cfg = AttrDict()
    with open(file_path) as f:
        cfg = merge_config(yaml.load(f, Loader=yaml.Loader), cfg)

    if READER_KEY in cfg:
        reader_cfg = cfg[READER_KEY]
        if reader_cfg.startswith("~"):
            reader_cfg = os.path.expanduser(reader_cfg)
        # os.path.isabs instead of startswith('/') so that absolute paths
        # are also recognised on non-POSIX platforms.
        if not os.path.isabs(reader_cfg):
            reader_cfg = os.path.join(os.path.dirname(file_path), reader_cfg)

        with open(reader_cfg) as f:
            merge_config(yaml.load(f, Loader=yaml.Loader))
        del cfg[READER_KEY]

    # Merged last so the main config overrides the reader config.
    merge_config(cfg)
    return global_config
#合并配置函数
def dict_merge(dct, merge_dct):
    """ Recursive dict merge. Inspired by :meth:``dict.update()``, instead of
    updating only top-level keys, dict_merge recurses down into dicts nested
    to an arbitrary depth, updating keys. The ``merge_dct`` is merged into
    ``dct`` in place.

    A key is merged recursively only when it already exists in ``dct`` with
    a ``dict`` value and the incoming value is a mapping; otherwise the
    incoming value replaces the existing one (by reference, not copied).

    Args:
        dct: dict onto which the merge is executed
        merge_dct: dct merged into dct

    Returns: dct
    """
    # ``collections.Mapping`` was removed in Python 3.10; the ABC lives in
    # ``collections.abc``. Fall back to the old location on Python 2.
    try:
        from collections.abc import Mapping
    except ImportError:
        from collections import Mapping

    for key, incoming in merge_dct.items():
        if (key in dct and isinstance(dct[key], dict) and
                isinstance(incoming, Mapping)):
            dict_merge(dct[key], incoming)
        else:
            dct[key] = incoming
    return dct
#合并全局配置函数
def merge_config(config, another_cfg=None):
    """
    Merge ``config`` into ``another_cfg``, or into the global config when
    ``another_cfg`` is not given.

    Args:
        config (dict): Config to be merged.
        another_cfg (dict): Optional merge target; defaults to the global
            config when None.

    Returns: the merge target (the global config unless ``another_cfg``
        was supplied)
    """
    if another_cfg is None:
        target = global_config
    else:
        target = another_cfg
    return dict_merge(target, config)
#获取注册模块
def get_registered_modules():
    """Return the global-config entries whose value is a ``SchemaDict``.

    Returns:
        dict: A new dict mapping each such key to its ``SchemaDict``.
    """
    registered = {}
    for name, entry in global_config.items():
        if isinstance(entry, SchemaDict):
            registered[name] = entry
    return registered
#把带 __op__ 属性的类包装成可调用对象:调用时以实例属性作为关键字参数转发给对应的op
def make_partial(cls):
    """Turn a class that declares ``__op__`` into a callable wrapper for it.

    The op is re-imported by module and name from ``cls.__op__``. Calling
    an instance of the returned class invokes the op with the instance's
    ``__dict__`` as keyword arguments; keyword arguments given at call time
    override them.

    ``cls.__category__`` defaults to ``'op'`` when unset or falsy. Unless
    ``cls.__append_doc__`` is falsy, the op's docstring is copied onto
    ``__call__`` and, on Python 3, onto the class and its ``__init__``.

    Args:
        cls (type): Class with an ``__op__`` attribute naming the op.

    Returns: cls, modified in place
    """
    op_module = importlib.import_module(cls.__op__.__module__)
    op = getattr(op_module, cls.__op__.__name__)
    cls.__category__ = getattr(cls, '__category__', None) or 'op'

    def partial_apply(self, *args, **kwargs):
        kwargs_ = self.__dict__.copy()
        kwargs_.update(kwargs)
        return op(*args, **kwargs_)

    if getattr(cls, '__append_doc__', True):  # XXX should default to True?
        # Set the doc on the plain function before attaching it: this works
        # on both Python 2 and 3 (unbound-method docs are read-only on 2).
        partial_apply.__doc__ = op.__doc__
        if sys.version_info[0] > 2:
            cls.__doc__ = "Wrapper for `{}` OP".format(op.__name__)
            cls.__init__.__doc__ = op.__doc__

    # Install __call__ unconditionally: previously it was only attached when
    # __append_doc__ was truthy, leaving such classes non-callable.
    cls.__call__ = partial_apply
    return cls
#注册模块类
def register(cls):
    """
    Register a module class in the global config under its class name.

    Classes that declare ``__op__`` are first passed through
    ``make_partial``. The schema produced by ``extract_schema`` is what
    gets stored.

    Args:
        cls (type): Module class to be registered.

    Returns: cls

    Raises:
        ValueError: If a class with the same name is already registered.
    """
    name = cls.__name__
    if name in global_config:
        raise ValueError("Module class already registered: {}".format(name))
    wrapped = make_partial(cls) if hasattr(cls, '__op__') else cls
    global_config[wrapped.__name__] = extract_schema(wrapped)
    return wrapped
#创建模块类
def create(cls_or_name, **kwargs):
    """
    Create an instance of given module class.

    The registered ``SchemaDict`` for the class is updated with ``kwargs``
    (this modifies the entry stored in the global config), validated, and
    then used as the constructor arguments. Fields listed in the schema's
    ``shared`` and ``inject`` attributes are resolved against the global
    config first.

    Args:
        cls_or_name (type or str): Class of which to create instance.
        **kwargs: Overrides applied to the registered config of the class.

    Returns: instance of type `cls_or_name`

    Raises:
        AssertionError: If ``cls_or_name`` is neither a class nor a string,
            or the module is not registered.
        ValueError: If an injected dependency is missing from the global
            config or has an unsupported type.
    """
    # isinstance (not an exact type match) so that classes with a custom
    # metaclass and str subclasses are accepted as well.
    assert isinstance(cls_or_name, (type, str)), \
        "should be a class or name of a class"
    if isinstance(cls_or_name, str):
        name = cls_or_name
    else:
        name = cls_or_name.__name__
    assert name in global_config and \
        isinstance(global_config[name], SchemaDict), \
        "the module {} is not registered".format(name)
    config = global_config[name]
    config.update(kwargs)
    config.validate()
    cls = getattr(config.pymodule, name)

    kwargs = {}
    kwargs.update(global_config[name])

    # parse `shared` annotation of registered modules
    if getattr(config, 'shared', None):
        for k in config.shared:
            target_key = config[k]
            shared_conf = config.schema[k].default
            assert isinstance(shared_conf, SharedConfig)
            if target_key is not None and not isinstance(target_key,
                                                         SharedConfig):
                continue  # value is given for the module
            elif shared_conf.key in global_config:
                # `key` is present in config
                kwargs[k] = global_config[shared_conf.key]
            else:
                kwargs[k] = shared_conf.default_value

    # parse `inject` annotation of registered modules
    if getattr(config, 'inject', None):
        for k in config.inject:
            target_key = config[k]
            # optional dependency
            if target_key is None:
                continue

            if isinstance(target_key, dict) or hasattr(target_key, '__dict__'):
                if 'name' not in target_key.keys():
                    continue
                inject_name = str(target_key['name'])
                if inject_name not in global_config:
                    raise ValueError(
                        "Missing injection name {} and check it's name in cfg file".
                        format(k))
                target = global_config[inject_name]
                # every entry other than 'name' overrides the injected
                # module's stored config
                for i, v in target_key.items():
                    if i == 'name':
                        continue
                    target[i] = v
                if isinstance(target, SchemaDict):
                    kwargs[k] = create(inject_name)
            elif isinstance(target_key, str):
                if target_key not in global_config:
                    raise ValueError("Missing injection config:", target_key)
                target = global_config[target_key]
                if isinstance(target, SchemaDict):
                    kwargs[k] = create(target_key)
                elif hasattr(target, '__dict__'):  # serialized object
                    kwargs[k] = target
            else:
                raise ValueError("Unsupported injection type:", target_key)
    # NOTE: kwargs is passed without a deep copy, so reference-typed values
    # (e.g., list, dict) are shared between the global config and the
    # created instance.
    return cls(**kwargs)
经过我的阅读,并一行一行地进行分析,我有如下收获:
1.通过上面的yaml配置文件,根据需要来传参注册相应的模块(类),模块主要是Data 、Model、Optimizer
a.Data中就包含了数据读取、数据预处理、数据封装。
b.Model中是检测算法的核心部分,model的architecture主要分为Backbone、Neck、Head、Post_process四部分:
- <1>Backbone 主干网络(如:resnet、darknet、mobilenet…)
- <2>Neck 颈(增强特征提取,如SPP,FPN)
- <3>Head 头 (算法核心,YOLO核心思想、Rcnn核心思想、Anchor Free核心思想)
- <4>Post_process 后处理(通常用来将特征与GT对齐)
c.Optimizer中就包含学习率策略、优化器的选择,用来更新网络权重
2.将所需要的模块(数据读取、模型构建、后处理、优化器)组建成目标检测器
3.有了目标检测器我们就可以根据需要用于训练、验证或者推理。