python用yaml装参数并支持命令行修改

效果:

  • 将实验用的参数写入 yaml 文件,而不是全部用 argparse 传,否则命令会很长,而且在 jupyter notebook 内用不了 argparse;
  • 同时支持在命令行临时加、改一些参数,避免事必要在 yaml 中改参数,比较灵活(如 grid-search 时遍历不同的 loss weights)。

最初是在 MMDetection 中看到这种写法,参考 [1] 中 --cfg-options 这个参数,核心是 DictAction 类,定义在 [2]。yaml 一些支持的写法参考 [3]。本文同时作为 python yaml 读、写简例。

Code

  • DictAction 类抄自 [2];
  • parse_cfg 函数读 yaml 参数,并按命令行输入加、改参数(覆盖 yaml),用 EasyDict 装;
  • 用 yaml 备份参数时,用 easydict2dict 将 EasyDict 递归改回 dict,yaml 会干净点。不用也行。
from argparse import Action, ArgumentParser, Namespace
import copy
from easydict import EasyDict
from typing import Any, Optional, Sequence, Tuple, Union
import yaml

class DictAction(Action):
    """抄自 MMEngine
    argparse action to split an argument into KEY=VALUE form
    on the first = and append to a dictionary. List options can
    be passed as comma separated values, i.e 'KEY=V1,V2,V3', or with explicit
    brackets, i.e. 'KEY=[V1,V2,V3]'. It also support nested brackets to build
    list/tuple values. e.g. 'KEY=[(V1,V2),(V3,V4)]'
    """

    @staticmethod
    def _parse_int_float_bool(val: str) -> Union[int, float, bool, Any]:
        """parse int/float/bool value in the string."""
        try:
            return int(val)
        except ValueError:
            pass
        try:
            return float(val)
        except ValueError:
            pass
        if val.lower() in ['true', 'false']:
            return True if val.lower() == 'true' else False
        if val == 'None':
            return None
        return val

    @staticmethod
    def _parse_iterable(val: str) -> Union[list, tuple, Any]:
        """Parse iterable values in the string.

        All elements inside '()' or '[]' are treated as iterable values.

        Args:
            val (str): Value string.

        Returns:
            list | tuple | Any: The expanded list or tuple from the string,
            or single value if no iterable values are found.

        Examples:
            >>> DictAction._parse_iterable('1,2,3')
            [1, 2, 3]
            >>> DictAction._parse_iterable('[a, b, c]')
            ['a', 'b', 'c']
            >>> DictAction._parse_iterable('[(1, 2, 3), [a, b], c]')
            [(1, 2, 3), ['a', 'b'], 'c']
        """

        def find_next_comma(string):
            """Find the position of next comma in the string.

            If no ',' is found in the string, return the string length. All
            chars inside '()' and '[]' are treated as one element and thus ','
            inside these brackets are ignored.
            """
            assert (string.count('(') == string.count(')')) and (
                    string.count('[') == string.count(']')), \
                f'Imbalanced brackets exist in {string}'
            end = len(string)
            for idx, char in enumerate(string):
                pre = string[:idx]
                # The string before this ',' is balanced
                if ((char == ',') and (pre.count('(') == pre.count(')'))
                        and (pre.count('[') == pre.count(']'))):
                    end = idx
                    break
            return end

        # Strip ' and " characters and replace whitespace.
        val = val.strip('\'\"').replace(' ', '')
        is_tuple = False
        if val.startswith('(') and val.endswith(')'):
            is_tuple = True
            val = val[1:-1]
        elif val.startswith('[') and val.endswith(']'):
            val = val[1:-1]
        elif ',' not in val:
            # val is a single value
            return DictAction._parse_int_float_bool(val)

        values = []
        while len(val) > 0:
            comma_idx = find_next_comma(val)
            element = DictAction._parse_iterable(val[:comma_idx])
            values.append(element)
            val = val[comma_idx + 1:]

        if is_tuple:
            return tuple(values)

        return values

    def __call__(self,
                 parser: ArgumentParser,
                 namespace: Namespace,
                 values: Union[str, Sequence[Any], None],
                 option_string: str = None):
        """Parse Variables in string and add them into argparser.

        Args:
            parser (ArgumentParser): Argument parser.
            namespace (Namespace): Argument namespace.
            values (Union[str, Sequence[Any], None]): Argument string.
            option_string (list[str], optional): Option string.
                Defaults to None.
        """
        # Copied behavior from `argparse._ExtendAction`.
        options = copy.copy(getattr(namespace, self.dest, None) or {})
        if values is not None:
            for kv in values:
                key, val = kv.split('=', maxsplit=1)
                options[key] = self._parse_iterable(val)
        setattr(namespace, self.dest, options)


def parse_cfg(yaml_file, update_dict={}):
    """load configurations from a yaml file & update from command-line argments
    Input:
        yaml_file: str, path to a yaml configuration file
        update_dict: dict, to modify/update options in those yaml configurations
    Output:
        cfg: EasyDict
    """
    with open(args.cfg, "r") as f:
        cfg = EasyDict(yaml.safe_load(f))

    if update_dict:
        assert isinstance(update_dict, dict)
        for k, v in update_dict.items():
            k_list = k.split('.')
            assert len(k_list) > 0
            if len(k_list) == 1: # 单级,e.g. lr=0.1
                cfg[k_list[0]] = v
            else: # 多级,e.g. optimizer.group1.lr=0.2
                ptr = cfg
                for i, _k in enumerate(k_list):
                    if i == len(k_list) - 1: # last layer
                        ptr[_k] = v
                    elif _k not in ptr:
                        ptr[_k] = EasyDict()

                    ptr = ptr[_k]

    return cfg


def easydict2dict(ed):
    """convert EasyDict to dict for clean yaml"""
    d = {}
    for k, v in ed.items():
        if isinstance(v, EasyDict):
            d[k] = easydict2dict(v)
        else:
            d[k] = v
    return d


if "__main__" == __name__:
    # test command:
    #   python config.py --cfg-options int=5 dict2.lr=8 dict2.newdict.newitem=fly

	import pprint
    parser = ArgumentParser()
    parser.add_argument("--cfg", type=str, default="config.yaml", help="指定 yaml")
    parser.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file. If the value to '
        'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
        'Note that the quotation marks are necessary and that no white space '
        'is allowed.')
    args = parser.parse_args()

	# 命令行临时加、改参数
    pprint.pprint(args.cfg_options) # dict
    # 读 yaml,并按命令行输入加、改参数
    cfg = parse_cfg(args.cfg, args.cfg_options)
    pprint.pprint(cfg)
    # 备份 yaml(写 yaml)
    with open("backup-config.yaml", 'w') as f:
        # yaml.dump(cfg, f) # OK
        yaml.dump(easydict2dict(cfg), f) # cleaner yaml

输入的 config.yaml

  • 语法参考 [3]
# An example yaml configuration file, used in utils/config.py as an example input.
# Ref: https://pyyaml.org/wiki/PyYAMLDocumentation

log_path: ./log
none: [~, null, None]
bool: [true, false, on, off, True, False]
int: 42				# <- 改它
float: 3.14159
list: [LITE, RES_ACID, SUS_DEXT]
list2:
  - -1
  - 0
  - 1
str:
  a
  2
  0.2
  # d: tom
dict: {hp: 13, sp: 5}
dict2:				# <- 加它
  lr: 0.01			# <- 改它
  decay_rate: 0.1
  name: jerry

测试:

# --cfg-options 支持多级指定(用「.」分隔)
python config.py --cfg config.yaml --cfg-options int=5 dict2.lr=8 dict2.newdict.newitem=fly

输出:

{'dict2.lr': 8, 'dict2.newdict.newitem': 'fly', 'int': 5}

{'bool': [True, False, True, False, True, False],
 'dict': {'hp': 13, 'sp': 5},
 'dict2': {'decay_rate': 0.1,
           'lr': 8,							# <- 改了
           'name': 'jerry',
           'newdict': {'newitem': 'fly'}},	# <- 加了
 'float': 3.14159,
 'int': 5,									# <- 改了
 'list': ['LITE', 'RES_ACID', 'SUS_DEXT'],
 'list2': [-1, 0, 1],
 'log_path': './log',
 'none': [None, None, 'None'],
 'str': 'a 2 0.2'}

json + argparse.Namespace

如果是基于别人的代码改,已经用 argparse 写了很长的参数列表,而现在想在 jupyter notebook 内用 checkpoint 跑些东西,可以用 json 存下 argparse 的参数,然后在 notebook 里读此 json 文件,并用 argparse.Namespace 类复原。这可以用作前文 yaml + EasyDict 的简单替代方案。

example

在用原代码 .py 脚本跑训练时,把命令行参数存入 json 文件:

# train.py
import argparse, json

parser = argparse.ArgumentParser()
parser.add_argument("lr", type=float)
# ...此处省略 n 行 parser.add_argument
args = parser.parse_args()

print(type(args), args.__dict__, args.lr) # <class 'argparse.Namespace'> {'lr': 0.1} 0.1

# 将跑此实验的命令行参数存入 json 文件
with open("cmdln-args.json", "w") as f:
    json.dump(args.__dict__, f, indent=1)

训练命令:python train.py 0.1 <以及后续 n 个参数>

在 jupyter notebook 内,不能在命令行传参,改从之前存下的 cmdln-args.json 读参数:

# test.ipynb
import argparse, json

with open("cmdln-args.json", "r") as f:
    args = argparse.Namespace(**json.load(f))

print(type(args), args.__dict__, args.lr) # <class 'argparse.Namespace'> {'lr': 0.1} 0.1

References

  1. open-mmlab/mmdetection/tools/train.py
  2. open-mmlab/mmengine/mmengine/config/config.py
  3. PyYAML Documentation
  • 5
    点赞
  • 9
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值