一、安装
train:
pip install optuna
web board:
pip install optuna-dashboard
二、 示例
执行 quick_start.py
import optuna
def objective(trial):
x = trial.suggest_float("x", -100, 100)
y = trial.suggest_categorical("y", [-1, 0, 1])
z = trial.suggest_float("z", -10, 10)
return x ** 2 + y + z
if __name__ == "__main__":
study = optuna.create_study(
storage="sqlite:///db.sqlite3", # Specify the storage URL here.
study_name="quadratic-simple3"
)
study.optimize(objective, n_trials=100)
print(f"Best value: {study.best_value} (params: {study.best_params})")
web board 查看
optuna-dashboard sqlite:///db.sqlite3
结果展示:
score优化示例
import optuna
study = optuna.create_study()
print(f"Sampler is {study.sampler.__class__.__name__}")
study = optuna.create_study(sampler=optuna.samplers.RandomSampler())
print(f"Sampler is {study.sampler.__class__.__name__}")
study = optuna.create_study(sampler=optuna.samplers.CmaEsSampler())
print(f"Sampler is {study.sampler.__class__.__name__}")
import logging
import sys
import sklearn.datasets
import sklearn.linear_model
import sklearn.model_selection
def objective(trial):
iris = sklearn.datasets.load_iris()
classes = list(set(iris.target))
train_x, valid_x, train_y, valid_y = sklearn.model_selection.train_test_split(
iris.data, iris.target, test_size=0.25, random_state=0
)
alpha = trial.suggest_float("alpha", 1e-5, 1e-1, log=True)
clf = sklearn.linear_model.SGDClassifier(alpha=alpha)
for step in range(100):
clf.partial_fit(train_x, train_y, classes=classes)
# Report intermediate objective value.
intermediate_value = 1.0 - clf.score(valid_x, valid_y)
trial.report(intermediate_value, step)
# Handle pruning based on the intermediate value.
if trial.should_prune():
raise optuna.TrialPruned()
return 1.0 - clf.score(valid_x, valid_y)
# Add stream handler of stdout to show the messages
optuna.logging.get_logger("optuna").addHandler(logging.StreamHandler(sys.stdout))
study = optuna.create_study(pruner=optuna.pruners.MedianPruner())
study.optimize(objective, n_trials=20)
print(study.trials_dataframe())
print("===")
输出:
三、总结
-
使用简单,能灵活自定义object funcion
-
有Dashboard 能实时的监控数据,进行画图