1. Installation
x86:
pip install nni
Other platforms (install from source):
git clone https://github.com/microsoft/nni.git
cd nni
pip install --upgrade setuptools pip wheel
python setup.py develop
2. Usage examples
2.1 Command to launch the experiment service:
nnictl create --config config.yml --port 8080
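When the launch has to be scripted rather than typed, the same command can be assembled in Python. This is only a minimal sketch: the helper names `build_create_command` and `launch` are hypothetical, the arguments are exactly the ones shown above, and actually running it assumes `nnictl` is installed and on the PATH.

```python
import subprocess


def build_create_command(config_path="config.yml", port=8080):
    """Return the argument list for `nnictl create` (hypothetical helper)."""
    return ["nnictl", "create", "--config", config_path, "--port", str(port)]


def launch(config_path="config.yml", port=8080):
    """Run the command; requires nnictl to be installed and on the PATH."""
    return subprocess.run(build_create_command(config_path, port), check=True)


if __name__ == "__main__":
    # Print the command instead of running it, so the sketch is safe to execute.
    print(" ".join(build_create_command()))
```

Passing the arguments as a list avoids shell quoting problems if the config path contains spaces.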
2.2 Define the hyperparameter search space:
{
    "C": {"_type": "uniform", "_value": [0.1, 1]},
    "kernel": {"_type": "choice", "_value": ["linear", "rbf", "poly", "sigmoid"]},
    "degree": {"_type": "choice", "_value": [1, 2, 3, 4]},
    "gamma": {"_type": "uniform", "_value": [0.01, 0.1]},
    "coef0": {"_type": "uniform", "_value": [0.01, 0.1]}
}
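To make the meaning of `_type` and `_value` concrete, the sketch below draws one random candidate from the search space above: `choice` picks one of the listed values, and `uniform` draws a float between the two bounds. It is a stand-in for illustration only; the function `sample_once` is hypothetical and is not how NNI's tuners (TPE, for instance) choose candidates.

```python
import random

# The search space from section 2.2, written as a Python dict.
SEARCH_SPACE = {
    "C": {"_type": "uniform", "_value": [0.1, 1]},
    "kernel": {"_type": "choice", "_value": ["linear", "rbf", "poly", "sigmoid"]},
    "degree": {"_type": "choice", "_value": [1, 2, 3, 4]},
    "gamma": {"_type": "uniform", "_value": [0.01, 0.1]},
    "coef0": {"_type": "uniform", "_value": [0.01, 0.1]},
}


def sample_once(space, rng):
    """Draw one candidate at random (illustrative stand-in, not NNI code)."""
    params = {}
    for name, spec in space.items():
        if spec["_type"] == "choice":
            # choice: pick one of the listed values
            params[name] = rng.choice(spec["_value"])
        elif spec["_type"] == "uniform":
            # uniform: draw a float between the lower and upper bound
            low, high = spec["_value"]
            params[name] = rng.uniform(low, high)
        else:
            raise ValueError(f"unsupported _type: {spec['_type']}")
    return params


if __name__ == "__main__":
    print(sample_once(SEARCH_SPACE, random.Random(0)))
```

Each sampled dict has the same keys that the trial script later receives from `nni.get_next_parameter()`.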
2.3 Define the run configuration:
The configuration file covers:
- Run time limit
- Trial runs
- Trial script
- Runtime environment
search_space:
  features:
    _type: choice
    _value: [ 128, 256, 512, 1024 ]
  lr:
    _type: loguniform
    _value: [ 0.0001, 0.1 ]
  momentum:
    _type: uniform
    _value: [ 0, 1 ]
trial_command: python model.py
trial_code_directory: .
trial_concurrency: 2
max_trial_number: 10
tuner:
  name: TPE
  class_args:
    optimize_mode: maximize
training_service:
  platform: local
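Two settings in this configuration are easier to read with a small calculation. `loguniform` for `lr` means the value is uniform on a log scale, so each decade between 0.0001 and 0.1 is about equally likely. And with `max_trial_number: 10` and `trial_concurrency: 2`, at least five back-to-back batches are needed if all trials take the same time. The sketch below illustrates both; the function names are hypothetical, and the batch count is a rough lower bound rather than a description of NNI's scheduler.

```python
import math
import random


def sample_loguniform(low, high, rng):
    """Draw exp(uniform(log(low), log(high))), i.e. uniform on a log scale."""
    return math.exp(rng.uniform(math.log(low), math.log(high)))


def min_batches(max_trial_number, trial_concurrency):
    """Lower bound on sequential batches when every batch runs at full concurrency."""
    return math.ceil(max_trial_number / trial_concurrency)


if __name__ == "__main__":
    rng = random.Random(0)
    # Five learning-rate candidates from the range used in the config above.
    print([sample_loguniform(0.0001, 0.1, rng) for _ in range(5)])
    # 10 trials, 2 at a time
    print(min_batches(10, 2))
```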
2.4 Define the trial script:
import nni
from sklearn.model_selection import train_test_split
from sklearn.datasets import load_digits
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC
import logging
import numpy as np

LOG = logging.getLogger('sklearn_classification')


def load_data():
    '''Load the scikit-learn digits dataset and standardize the features'''
    digits = load_digits()
    X_train, X_test, y_train, y_test = train_test_split(
        digits.data, digits.target, random_state=99, test_size=0.25)
    # Fit the scaler on the training split only, then apply it to both splits
    ss = StandardScaler()
    X_train = ss.fit_transform(X_train)
    X_test = ss.transform(X_test)
    return X_train, X_test, y_train, y_test


def get_default_parameters():
    '''Get the default parameters, used for any key the tuner does not supply'''
    params = {
        'C': 1.0,
        'kernel': 'linear',
        'degree': 3,
        'gamma': 0.01,
        'coef0': 0.01
    }
    return params


def get_model(PARAMS):
    '''Get model according to parameters'''
    model = SVC()
    model.C = PARAMS.get('C')
    model.kernel = PARAMS.get('kernel')
    model.degree = PARAMS.get('degree')
    model.gamma = PARAMS.get('gamma')
    model.coef0 = PARAMS.get('coef0')
    return model


def run(X_train, X_test, y_train, y_test, model):
    '''Train the model and report the test accuracy to NNI'''
    model.fit(X_train, y_train)
    score = model.score(X_test, y_test)
    LOG.debug('score: %s', score)
    nni.report_final_result(score)


if __name__ == '__main__':
    X_train, X_test, y_train, y_test = load_data()
    try:
        # get parameters from tuner
        RECEIVED_PARAMS = nni.get_next_parameter()
        LOG.debug(RECEIVED_PARAMS)
        # tuner values override the defaults
        PARAMS = get_default_parameters()
        PARAMS.update(RECEIVED_PARAMS)
        LOG.debug(PARAMS)
        model = get_model(PARAMS)
        run(X_train, X_test, y_train, y_test, model)
    except Exception as exception:
        LOG.exception(exception)
        raise
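The key step in the script is merging the tuner's parameters into the defaults, so that a partial or empty set from the tuner still yields a complete configuration. The sketch below isolates that step without NNI or scikit-learn; `merge_params` and the sample `received` dict are hypothetical and only mirror the `PARAMS.update(RECEIVED_PARAMS)` call above.

```python
def merge_params(defaults, received):
    """Return a new dict: the defaults, overridden by whatever the tuner sent."""
    merged = dict(defaults)
    merged.update(received or {})
    return merged


# Same defaults as get_default_parameters() in the trial script.
DEFAULTS = {"C": 1.0, "kernel": "linear", "degree": 3, "gamma": 0.01, "coef0": 0.01}

if __name__ == "__main__":
    # Hypothetical tuner output that sets only two of the five parameters.
    received = {"C": 0.5, "kernel": "rbf"}
    print(merge_params(DEFAULTS, received))
```

Unlike the in-place `update` in the script, this version copies the defaults first, so the same defaults dict can be reused across calls.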
2.5 Web portal UI
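The training duration can be set.
The portal service on port 8080 automatically selects parameters according to the yml file, and the optimization algorithm submits the trial script.
The accuracy of each training result is displayed.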
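3. Summary
Pros:
- The code structure is clear, and new algorithms can be added flexibly.
- Parameter configuration is flexible and customizable.
- Training on different platforms and environments is supported.
- The UI and the results display are good.
Cons:
- Tasks are submitted in batches through a dedicated web service, with yml, conf, and other configuration files submitted in turn, which does not fit our submission workflow well.
- Results have to be fetched from the corresponding service, and runs follow the defined configuration.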