fairseq运行命令中的config-dir和config-name参数

fairseq的奇妙传参

fairseq的运行命令(例如训练wav2vec2)示例如下:

$ fairseq-hydra-train \
    task.data=/path/to/data \
    --config-dir /path/to/fairseq-py/examples/wav2vec/config/pretraining \
    --config-name wav2vec2_large_librivox 
   或者
$ python fairseq_cli/hydra_train.py \
    task.data=/path/to/data \
    --config-dir /path/to/fairseq-py/examples/wav2vec/config/pretraining \
    --config-name wav2vec2_large_librivox 

由于需要适应一些训练要求,我需要去掉命令行中传入的参数,改在训练文件中直接传入。可当我在整个fairseq项目里搜索config-dir和config-name时,找不到有parser.add_argument来添加或者定义这两个参数的地方。那这两个参数是如何传入的呢?玄机就在hydra里。这个是hydra_train.py中运行的主函数:

def cli_main():
    try:
        from hydra._internal.utils import get_args

        cfg_name = get_args().config_name or "config"
    except:
        logger.warning("Failed to get config name from hydra args")
        cfg_name = "config"

    hydra_init(cfg_name)
    hydra_main()

我们在自己环境中python目录下,可以找到hydra安装包的所在位置(如python3.7/site-packages/hydra),找到hydra/_internal/utils.py中的get_args函数:

def get_args(args: Optional[Sequence[str]] = None) -> Any:
    return get_args_parser().parse_args(args=args)

在其调用的get_args_parser函数中,我们看到了相关三个参数的定义:

	parser.add_argument(
        "--config-path",
        "-cp",
        help="""Overrides the config_path specified in hydra.main().
                    The config_path is absolute or relative to the Python file declaring @hydra.main()""",
    )

    parser.add_argument(
        "--config-name",
        "-cn",
        help="Overrides the config_name specified in hydra.main()",
    )

    parser.add_argument(
        "--config-dir",
        "-cd",
        help="Adds an additional config dir to the config search path",
    )

可以看到config-path和config-name两个传参是优先于hydra.main()中传入的参数的(fairseq_cli/hydra_train.py原代码中的@hydra.main(config_path=os.path.join("..", "fairseq", "config"), config_name="config")),config-dir是参数搜索路径一个额外的目录。因此如果不想训练时在命令行传入两个参数的话,可以改在hydra.main中传入,形如@hydra.main(config_path="../examples/wav2vec/config/pretraining", config_name="wav2vec2_large_librivox").
task.data参数也可以直接在config yaml文件中写上,就不需要命令行传参了。
需要注意的一点是,hydra.main修改以及config文件添加task.data之后,在命令行直接运行python fairseq_cli/hydra_train.py会报如下错误:

Traceback (most recent call last):
  File "fairseq_cli/hydra_train_withconfig.py", line 27, in hydra_main
    _hydra_main(cfg)
  File "fairseq_cli/hydra_train_withconfig.py", line 49, in _hydra_main
    OmegaConf.to_container(cfg, resolve=True, enum_to_str=True)
omegaconf.errors.ConfigKeyError: str interpolation key 'common.tpu' not found

Set the environment variable HYDRA_FULL_ERROR=1 for a complete stack trace.

经排查错误出现在hydra_train.py中的cli_main函数上,没有命令行传参,得到的cfg_name是“config”,而不是“wav2vec2_large_librivox”,因此这里直接将hydra.main中指定的config_name定义为cfg_name即可解决报错:

def cli_main():
    #try:
    #    from hydra._internal.utils import get_args

    #    cfg_name = get_args().config_name or "config"
    #except:
    #    logger.warning("Failed to get config name from hydra args")
    #    cfg_name = "config"
    cfg_name = "wav2vec2_large_librivox"
    hydra_init(cfg_name)
    hydra_main()

这样一来就可以在代码中定义config-dir和config-name了,直接使用python fairseq_cli/hydra_train.py即可开始训练。

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值