使用argparse库和yacs库进行深度学习项目的配置
argparse库使用
argparse
模块可以让人轻松编写用户友好的命令行接口。 程序定义它需要哪些参数,argparse
将会知道如何从 sys.argv
解析它们。 argparse
模块还能自动生成帮助和用法消息文本。 该模块还会在用户向程序传入无效参数时发出错误消息。
使用方法:
①创建解析器
parser = argparse.ArgumentParser(description='your description')
ArgumentParser 对象包含将命令行解析成 Python 数据类型所需的全部信息。
②添加参数
给一个 ArgumentParser 添加程序参数信息是通过调用 add_argument() 方法完成的。通常,这些调用指定 ArgumentParser 如何获取命令行字符串并将其转换为对象。这些信息在 parse_args() 调用时被存储和使用。例如:
parser.add_argument('--cfg',
help='experiment configure file name',
required=False,
type=str,
default = '../experiments/cityscapes/seg_hrnet_w48_train_512x1024_sgd_lr1e-2_wd5e-4_bs_12_epoch484.yaml'
)
parser.add_argument('--seed', type=int, default=304)
parser.add_argument("--local_rank", type=int, default=-1)
parser.add_argument('opts',
help="Modify config options using the command-line",
default=None,
nargs=argparse.REMAINDER)
然后,调用 parse_args() 将返回一个具有 integers 和 accumulate 两个属性的对象。integers 属性将是一个包含一个或多个整数的列表,而 accumulate 属性当命令行中指定了 --sum 参数时将是 sum() 函数,否则则是 max() 函数。
③解析参数
ArgumentParser 通过 parse_args() 方法解析参数。它将检查命令行,把每个参数转换为适当的类型然后调用相应的操作。在大多数情况下,这意味着一个简单的 Namespace 对象将从命令行解析出的属性构建:
args = parser.parse_args() # 命令行参数进行解析
在脚本中,通常 parse_args() 会被不带参数调用,而 ArgumentParser 将自动从 sys.argv 中确定命令行参数
yacs库使用
YACS 被创建为一个轻量级库来定义和管理系统配置,例如在为科学实验而设计的软件中常见的配置。这些“配置”通常涵盖用于训练机器学习模型的超参数或可配置模型超参数等概念,例如卷积神经网络的深度。当做实验时,其可重复性至关重要,因此需要一种可靠的方法来序列化实验配置。YACS 使用 YAML 作为一种简单的、人类可读的序列化格式。范式是:your code + a YACS config for experiment E (+ external dependencies + hardware + other nuisance terms ...) = reproducible experiment E
。虽然您可能无法控制所有内容,但至少您可以控制代码和实验性配置。
YACS主要有两种使用方式:
- 全局变量
- 局部变量
作者建议使用局部变量。
①引入库
要在项目中使用 YACS,请首先创建一个项目配置文件,通常称为config.py
或 defaults.py
。此文件是所有可配置选项的一站式参考点。它应该有很好的文档记录,并为所有选项提供合理的默认值。
# my_project/config.py
from yacs.config import CfgNode as CN
_C = CN()
_C.SYSTEM = CN()
# 设置实验中GPU的数量
_C.SYSTEM.NUM_GPUS = 8
# 设置训练时的线程数
_C.SYSTEM.NUM_WORKERS = 4
_C.TRAIN = CN()
# 一个非常重要的超参数,不知道干啥的
_C.TRAIN.HYPERPARAMETER_1 = 0.1
# 非常重要的尺度
_C.TRAIN.SCALES = (2, 4, 8, 16)
def get_cfg_defaults():
"""Get a yacs CfgNode object with default values for my_project."""
# Return a clone so that the defaults will not be altered
# This is for the "local variable" use pattern
return _C.clone()
# Alternatively, provide a way to import the defaults as
# a global singleton:
# cfg = _C # users can `from config import cfg`
②设置YAML文件
创建YAML配置文件,通常情况下应该为每个实验创建一个,每个配置文件只覆盖一次会在实验中改变的选项参数。
# my_project/experiment.yaml
SYSTEM:
NUM_GPUS: 2
TRAIN:
SCALES: (1, 2)
③应用配置参数
在实际项目中使用配置文件。在初始化参数设置之后,**最好是将其冻结,以防在后续的操作中程序调用freeze()
方法修改参数。**如下面的代码所示,设置系统参数可以通过导入cfg
,并直接通过它使用全局变量的方法来设置,或者cfg
作为参数进行复制和传递。
# my_project/main.py
import my_project
from config import get_cfg_defaults # 局部变量使用方式, or:
# from config import cfg # 全局变量使用方式
if __name__ == "__main__":
# `cfg`作为局部变量访问的例子
cfg = get_cfg_defaults()
cfg.merge_from_file("experiment.yaml")
cfg.freeze()
print(cfg)
# `cfg`作为全局变量访问的例子
if cfg.SYSTEM.NUM_GPUS > 0:
my_project.setup_multi_gpu_support()
model = my_project.create_model(cfg)
联合使用Argparse、YACS库
主要是使用argparse库获取YAML文件地址和输入一些变量,并使用yacs库更新yaml里面的配置信息,后续主要使用yacs中的_C
中包含的信息。
defaults.py
:
# my_project/config.py
from yacs.config import CfgNode as CN
_C = CN()
_C.SYSTEM = CN()
# 设置实验中GPU的数量
_C.SYSTEM.NUM_GPUS = 8
# 设置训练时的线程数
_C.SYSTEM.NUM_WORKERS = 4
_C.TRAIN = CN()
# 一个非常重要的超参数,不知道干啥的
_C.TRAIN.HYPERPARAMETER_1 = 0.1
# 非常重要的尺度
_C.TRAIN.SCALES = (2, 4, 8, 16)
def update_config(cfg, args):
cfg.defrost() # 解冻参数
cfg.merge_from_file(args.cfg)
cfg.merge_from_list(args.opts)
cfg.freeze()
YAML文件
# my_project/experiment.yaml
SYSTEM:
NUM_GPUS: 2
TRAIN:
SCALES: (1, 2)
main.py
import argparse
from defaults import _C as config
from defaults import update_config
def parse_args():
parser = argparse.ArgumentParser(description='Train segmentation network')
parser.add_argument('--cfg',
help='experiment configure file name',
required=False,
type=str,
default='./yaml_file.yaml'
)
parser.add_argument('--seed', type=int, default=304)
parser.add_argument("--local_rank", type=int, default=-1)
parser.add_argument('opts',
help="Modify config options using the command-line",
default=None,
nargs=argparse.REMAINDER)
args = parser.parse_args() # 命令行参数进行解析
update_config(config, args)
return args
if __name__ == '__main__':
args= parse_args()
print("args:\n",args,"\n")
print("config:\n",config)
附件资源中代码可作参考,代码同上。
参考网页: