fairseq的奇妙传参
fairseq的运行命令(例如训练wav2vec2)示例如下:
$ fairseq-hydra-train \
task.data=/path/to/data \
--config-dir /path/to/fairseq-py/examples/wav2vec/config/pretraining \
--config-name wav2vec2_large_librivox
或者
$ python fairseq_cli/hydra_train.py \
task.data=/path/to/data \
--config-dir /path/to/fairseq-py/examples/wav2vec/config/pretraining \
--config-name wav2vec2_large_librivox
由于需要适应一些训练要求,我需要去掉命令行中传入的参数,改在训练文件中直接传入。可当我在整个fairseq项目里搜索config-dir和config-name时,找不到有parser.add_argument来添加或者定义这两个参数的地方。那这两个参数是如何传入的呢?玄机就在hydra里。这个是hydra_train.py中运行的主函数:
def cli_main():
try:
from hydra._internal.utils import get_args
cfg_name = get_args().config_name or "config"
except:
logger.warning("Failed to get config name from hydra args")
cfg_name = "config"
hydra_init(cfg_name)
hydra_main()
我们在自己环境中python目录下,可以找到hydra安装包的所在位置(如python3.7/site-packages/hydra
),找到hydra/_internal/utils.py
中的get_args函数:
def get_args(args: Optional[Sequence[str]] = None) -> Any:
return get_args_parser().parse_args(args=args)
在其调用的get_args_parser函数中,我们看到了相关三个参数的定义:
parser.add_argument(
"--config-path",
"-cp",
help="""Overrides the config_path specified in hydra.main().
The config_path is absolute or relative to the Python file declaring @hydra.main()""",
)
parser.add_argument(
"--config-name",
"-cn",
help="Overrides the config_name specified in hydra.main()",
)
parser.add_argument(
"--config-dir",
"-cd",
help="Adds an additional config dir to the config search path",
)
可以看到config-path和config-name两个传参是优先于hydra.main()中传入的参数的(fairseq_cli/hydra_train.py原代码中的@hydra.main(config_path=os.path.join("..", "fairseq", "config"), config_name="config")
),config-dir是参数搜索路径一个额外的目录。因此如果不想训练时在命令行传入两个参数的话,可以改在hydra.main中传入,形如@hydra.main(config_path="../examples/wav2vec/config/pretraining", config_name="wav2vec2_large_librivox")
.
task.data参数也可以直接在config yaml文件中写上,就不需要命令行传参了。
需要注意的一点是,hydra.main修改以及config文件添加task.data之后,在命令行直接运行python fairseq_cli/hydra_train.py
会报如下错误:
Traceback (most recent call last):
File "fairseq_cli/hydra_train_withconfig.py", line 27, in hydra_main
_hydra_main(cfg)
File "fairseq_cli/hydra_train_withconfig.py", line 49, in _hydra_main
OmegaConf.to_container(cfg, resolve=True, enum_to_str=True)
omegaconf.errors.ConfigKeyError: str interpolation key 'common.tpu' not found
Set the environment variable HYDRA_FULL_ERROR=1 for a complete stack trace.
经排查错误出现在hydra_train.py中的cli_main函数上,没有命令行传参,得到的cfg_name是“config”,而不是“wav2vec2_large_librivox”,因此这里直接将hydra.main中指定的config_name定义为cfg_name即可解决报错:
def cli_main():
#try:
# from hydra._internal.utils import get_args
# cfg_name = get_args().config_name or "config"
#except:
# logger.warning("Failed to get config name from hydra args")
# cfg_name = "config"
cfg_name = "wav2vec2_large_librivox"
hydra_init(cfg_name)
hydra_main()
这样一来就可以在代码中定义config-dir和config-name了,直接使用python fairseq_cli/hydra_train.py
即可开始训练。