简介:
Arbiter是DL4J套件下的一个组件，用于优化神经网络的超参数。它允许用户提供超参数调优范围，该框架可自行进行各类参数自动化调整，并提供可视化的参数调整分析报告。
<!-- Arbiter - used for hyperparameter optimization (grid/random search) -->
<dependency>
<groupId>org.deeplearning4j</groupId>
<artifactId>arbiter-deeplearning4j</artifactId>
<version>1.0.0-beta3</version>
</dependency>
<dependency>
<groupId>org.deeplearning4j</groupId>
<artifactId>arbiter-ui_2.11</artifactId>
<version>1.0.0-beta3</version>
</dependency>
前提知识储备:
在使用Arbiter之前需要了解DL4J中的几项配置:
- NeuralNetConfiguration
- MultiLayerConfiguration
- ComputationGraphConfiguration
建立一个超参数优化执行器的过程:
- 创建超参数搜索空间
- 创建针对超参数搜索空间的候选参数生成器
- 接下来的步骤可以顺序不定:
- 创建数据源
- 创建模型保存方法
- 创建得分函数
- 创建终止条件
- 使用以上创建的对象创建优化配置对象
- 使用优化配置对象构建优化执行器
`public class ArbiterTest {

    /**
     * Runs a random-search hyperparameter optimization over a simple MLP for MNIST:
     * defines a search space (learning rate + hidden layer size), generates random
     * candidates, trains and scores each one locally, shows progress in the Arbiter
     * UI, and prints the configuration of the best model found.
     */
    public static void main(String[] args) throws IOException, InterruptedException {
        // Hyperparameter: continuous learning-rate range [0.0001, 0.1]
        ParameterSpace<Double> learningRateHyperparam = new ContinuousParameterSpace(0.0001, 0.1);
        // Hyperparameter: hidden-layer size, integer range [16, 256]
        ParameterSpace<Integer> layerSizeHyperparam = new IntegerParameterSpace(16, 256);

        // Search space: fixed settings plus the two hyperparameters above
        MultiLayerSpace hyperparameterSpace = new MultiLayerSpace.Builder()
                .weightInit(WeightInit.XAVIER)
                .l2(0.0001)
                .updater(new SgdSpace(learningRateHyperparam)) // learning rate is searched
                .addLayer(new DenseLayerSpace.Builder()
                        .nIn(784) // MNIST: 28x28 images, flattened
                        .activation(Activation.LEAKYRELU)
                        .nOut(layerSizeHyperparam) // hidden size is searched
                        .build())
                .addLayer(new OutputLayerSpace.Builder()
                        .nOut(10) // 10 digit classes
                        .activation(Activation.SOFTMAX)
                        .lossFunction(LossFunctions.LossFunction.MCXENT)
                        .build())
                .numEpochs(2)
                .build();

        // Candidate generator: plain random search over the space
        CandidateGenerator candidateGenerator = new RandomSearchGenerator(hyperparameterSpace, null);

        // Data source: class + properties handed to every generated candidate
        // for training and evaluation
        Class<? extends DataSource> dataSourceClass = ExampleDataSource.class;
        Properties dataSourceProperties = new Properties();
        dataSourceProperties.setProperty("minibatchSize", "64");

        // Model saver: persist each candidate's results under this directory.
        // NOTE(review): File.delete() fails silently on a non-empty directory,
        // so stale results from a previous run may remain — consider a recursive
        // delete if a clean directory is required.
        String baseSaveDir = "arbiterExample/";
        File f = new File(baseSaveDir);
        if (f.exists()) {
            f.delete();
        }
        f.mkdirs();
        ResultSaver modelSaver = new FileModelSaver(baseSaveDir);

        // Score function: classification accuracy (to be maximized)
        ScoreFunction scoreFunction = new EvaluationScoreFunction(Evaluation.Metric.ACCURACY);

        // Termination: stop after 15 minutes or 10 candidates, whichever comes first
        TerminationCondition[] terminationConditions = {
                new MaxTimeCondition(15, TimeUnit.MINUTES),
                new MaxCandidatesCondition(10)
        };

        // Assemble the optimization configuration from the pieces above
        OptimizationConfiguration configuration = new OptimizationConfiguration.Builder()
                .candidateGenerator(candidateGenerator)
                .dataSource(dataSourceClass, dataSourceProperties)
                .modelSaver(modelSaver)
                .scoreFunction(scoreFunction)
                .terminationConditions(terminationConditions)
                .build();

        // Local (single-JVM) runner. MultiLayerNetworkTaskCreator matches the
        // MultiLayerSpace above; for a ComputationGraphSpace use
        // ComputationGraphTaskCreator instead. Other execution backends
        // (e.g. Spark) can be substituted here.
        // BUG FIX: this declaration was commented out in the original, so every
        // later runner.* call could not compile.
        IOptimizationRunner runner =
                new LocalOptimizationRunner(configuration, new MultiLayerNetworkTaskCreator());

        // Start the UI server so the optimization progress can be visualized
        StatsStorage ss = new FileStatsStorage(new File("arbiterExampleUiStats.dl4j"));
        runner.addListeners(new ArbiterStatusListener(ss));
        UIServer.getInstance().attach(ss);

        // Run the search (blocks until a termination condition fires)
        runner.execute();

        // Print out some basic stats regarding the optimization procedure
        String s = "Best score: " + runner.bestScore() + "\n" +
                "Index of model with best score: " + runner.bestScoreCandidateIndex() + "\n" +
                "Number of configurations evaluated: " + runner.numCandidatesCompleted() + "\n";
        System.out.println(s);

        // Get all results, and print out details of the best result:
        int indexOfBestResult = runner.bestScoreCandidateIndex();
        List<ResultReference> allResults = runner.getResults();
        OptimizationResult bestResult = allResults.get(indexOfBestResult).getResult();
        MultiLayerNetwork bestModel = (MultiLayerNetwork) bestResult.getResultReference().getResultModel();
        System.out.println("\n\nConfiguration of best model:\n");
        System.out.println(bestModel.getLayerWiseConfigurations().toJson());

        // Wait a while so the UI stays reachable before exiting
        Thread.sleep(60000);
        UIServer.getInstance().stop();
    }

    /**
     * DataSource implementation feeding MNIST train/test iterators to each
     * candidate. The minibatch size is supplied via the "minibatchSize"
     * property (default 16 when absent).
     */
    public static class ExampleDataSource implements DataSource {
        private int minibatchSize;

        public ExampleDataSource() {
            // no-arg constructor: Arbiter instantiates this class reflectively
        }

        @Override
        public void configure(Properties properties) {
            this.minibatchSize = Integer.parseInt(properties.getProperty("minibatchSize", "16"));
        }

        @Override
        public Object trainData() {
            try {
                // fixed seed so every candidate sees the same data ordering
                return new MnistDataSetIterator(minibatchSize, true, 12345);
            } catch (Exception e) {
                throw new RuntimeException(e);
            }
        }

        @Override
        public Object testData() {
            try {
                return new MnistDataSetIterator(minibatchSize, false, 12345);
            } catch (Exception e) {
                throw new RuntimeException(e);
            }
        }

        @Override
        public Class<?> getDataType() {
            return DataSetIterator.class;
        }
    }
} `