一、安装NNI
pip install nni
不要用豆瓣的源,依赖包不全,建议用阿里云的源。
import nni
没有问题的时候就说明装好了。
二、配置
2.1 项目代码包含NNI
# 引入nni
import nni
# 这是个字典
params = vars(get_params())
# 从参数空间的json文件中按照某种策略去除一组
tuner_params= nni.get_next_parameter()
# 更新参数字典
params.update(tuner_params)
# 上报中间结果
nni.report_intermediate_result(test_acc)
# 上报最终结果
nni.report_final_result(best_acc)
2.2 参数空间JSON
# search_space.json 这个文件放哪都行
{
"dropout_rate":{"_type":"uniform","_value":[0.5, 0.9]},
"conv_size":{"_type":"choice","_value":[2,3,5,7]},
"hidden_size":{"_type":"choice","_value":[124, 512, 1024]},
"batch_size": {"_type":"choice", "_value": [1, 4, 8, 16, 32]},
"learning_rate":{"_type":"ch