背景
编码项目代码时,往往涉及到很多的超参数。Hydra可以帮助整理这些超参数,使实验过程中的参数设置更清晰。
从config.yaml文件中读取超参数
假设当前的文件路径为:
├─configs
│ └─config.yaml
└─main.py
config.yaml的内容是:
name: exp
save_dir: ./checkpoint
data:
dataroot: ./data
batch_size: 64
main.py的内容是:
from omegaconf import DictConfig, OmegaConf
import hydra
@hydra.main(version_base=None, config_path="configs", config_name="config")
def func(cfg):
cfg_str = OmegaConf.to_yaml(cfg) # 将cfg转换成string格式,方便打印
print(cfg_str)
if __name__ == "__main__":
func()
在@hydra.main
中,config_path代表存放配置文件的文件夹,config_name代表主配置文件的名称。@hydra.main读取这些配置后形成数据格式DictConfig,然后传递给func。
读取多个.yaml的超参数
假设当前的文件路径为:
├─configs
│ ├─data
│ │ └─data_1.yaml
│ ├─model
│ │ └─model_1.yaml
│ └─config.yaml
└─main.py
.yaml的内容分别是:
# config.yaml
name: exp
save_dir: ./checkpoint
defaults:
- data: data_1
- model: model_1
# data_1.yaml
dataroot: ./data
# model_1.yaml
n_layers: 3
执行main.py后输出:
data:
dataroot: ./data
model:
n_layers: 3
name: exp
save_dir: ./checkpoint