DL4J - Arbiter

Introduction

  Arbiter is a component of the DL4J suite used for hyperparameter optimization of neural networks. The user defines a search range for each hyperparameter, and the framework automatically tunes the parameters within those ranges and produces a visual report of the tuning process.

```xml
<!-- Arbiter - used for hyperparameter optimization (grid/random search) -->
<dependency>
    <groupId>org.deeplearning4j</groupId>
    <artifactId>arbiter-deeplearning4j</artifactId>
    <version>1.0.0-beta3</version>
</dependency>
<dependency>
    <groupId>org.deeplearning4j</groupId>
    <artifactId>arbiter-ui_2.11</artifactId>
    <version>1.0.0-beta3</version>
</dependency>
```

Prerequisites

Before using Arbiter, you should be familiar with the following DL4J configuration classes:

  • NeuralNetConfiguration
  • MultiLayerConfiguration
  • ComputationGraphConfiguration

Steps to set up a hyperparameter optimization run

  • Define the hyperparameter search space
  • Create a candidate generator over that search space
  • The following steps can be done in any order:
    • Create a data source
    • Choose how models are saved
    • Define a score function
    • Define termination conditions
  • Combine the objects above into an optimization configuration
  • Build an optimization runner from that configuration
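Conceptually, a random-search candidate generator simply samples each hyperparameter uniformly from its declared range. A minimal plain-Java sketch of that idea (no DL4J dependency; the class and variable names here are hypothetical, not Arbiter's API), using the same two ranges as the full example in this article:

```java
import java.util.Random;

public class RandomSearchSketch {
    public static void main(String[] args) {
        Random rng = new Random(12345);
        for (int i = 0; i < 3; i++) {
            // Analogous to ContinuousParameterSpace(0.0001, 0.1): uniform double in the range
            double learningRate = 1e-4 + rng.nextDouble() * (0.1 - 1e-4);
            // Analogous to IntegerParameterSpace(16, 256): uniform integer, bounds inclusive
            int layerSize = 16 + rng.nextInt(256 - 16 + 1);
            System.out.printf("candidate %d: lr=%.5f, layerSize=%d%n", i, learningRate, layerSize);
        }
    }
}
```

Each sampled pair is one "candidate"; Arbiter's runner then trains and scores it to decide which region of the space looks best.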

```java
//Imports (package locations per DL4J/Arbiter 1.0.0-beta3; verify against your version)
import java.io.File;
import java.io.IOException;
import java.util.List;
import java.util.Properties;
import java.util.concurrent.TimeUnit;

import org.deeplearning4j.api.storage.StatsStorage;
import org.deeplearning4j.arbiter.MultiLayerSpace;
import org.deeplearning4j.arbiter.conf.updater.SgdSpace;
import org.deeplearning4j.arbiter.layers.DenseLayerSpace;
import org.deeplearning4j.arbiter.layers.OutputLayerSpace;
import org.deeplearning4j.arbiter.optimize.api.CandidateGenerator;
import org.deeplearning4j.arbiter.optimize.api.OptimizationResult;
import org.deeplearning4j.arbiter.optimize.api.ParameterSpace;
import org.deeplearning4j.arbiter.optimize.api.data.DataSource;
import org.deeplearning4j.arbiter.optimize.api.saving.ResultReference;
import org.deeplearning4j.arbiter.optimize.api.saving.ResultSaver;
import org.deeplearning4j.arbiter.optimize.api.score.ScoreFunction;
import org.deeplearning4j.arbiter.optimize.api.termination.MaxCandidatesCondition;
import org.deeplearning4j.arbiter.optimize.api.termination.MaxTimeCondition;
import org.deeplearning4j.arbiter.optimize.api.termination.TerminationCondition;
import org.deeplearning4j.arbiter.optimize.config.OptimizationConfiguration;
import org.deeplearning4j.arbiter.optimize.generator.RandomSearchGenerator;
import org.deeplearning4j.arbiter.optimize.parameter.continuous.ContinuousParameterSpace;
import org.deeplearning4j.arbiter.optimize.parameter.integer.IntegerParameterSpace;
import org.deeplearning4j.arbiter.optimize.runner.IOptimizationRunner;
import org.deeplearning4j.arbiter.optimize.runner.LocalOptimizationRunner;
import org.deeplearning4j.arbiter.saver.local.FileModelSaver;
import org.deeplearning4j.arbiter.scoring.impl.EvaluationScoreFunction;
import org.deeplearning4j.arbiter.task.MultiLayerNetworkTaskCreator;
import org.deeplearning4j.arbiter.ui.listener.ArbiterStatusListener;
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.ui.api.UIServer;
import org.deeplearning4j.ui.storage.FileStatsStorage;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.lossfunctions.LossFunctions;

public class ArbiterTest {

    public static void main(String[] args) throws IOException, InterruptedException {

        //Learning rate search range
        ParameterSpace<Double> learningRateHyperparam = new ContinuousParameterSpace(0.0001, 0.1);
        //Layer size search range (the size of the hidden layer, not the number of layers)
        ParameterSpace<Integer> layerSizeHyperparam = new IntegerParameterSpace(16, 256);

        //Build the hyperparameter search space
        MultiLayerSpace hyperparameterSpace = new MultiLayerSpace.Builder()
                .weightInit(WeightInit.XAVIER)
                .l2(0.0001)
                .updater(new SgdSpace(learningRateHyperparam))    //learning rate is a hyperparameter
                .addLayer(new DenseLayerSpace.Builder()
                        .nIn(784)
                        .activation(Activation.LEAKYRELU)
                        .nOut(layerSizeHyperparam)                //layer size is a hyperparameter
                        .build())
                .addLayer(new OutputLayerSpace.Builder()
                        .nOut(10)
                        .activation(Activation.SOFTMAX)
                        .lossFunction(LossFunctions.LossFunction.MCXENT)
                        .build())
                .numEpochs(2)
                .build();

        //Candidate generator: here, random search over the space
        CandidateGenerator candidateGenerator = new RandomSearchGenerator(hyperparameterSpace, null);

        //Data source: provides training and test data to each generated candidate
        Class<? extends DataSource> dataSourceClass = ExampleDataSource.class;
        Properties dataSourceProperties = new Properties();
        dataSourceProperties.setProperty("minibatchSize", "64");

        //Model saver: where and how each trained candidate is stored
        String baseSaveDir = "arbiterExample/";
        File f = new File(baseSaveDir);
        if (f.exists()) f.delete();
        f.mkdir();
        ResultSaver modelSaver = new FileModelSaver(baseSaveDir);

        //Score function: a single number to maximize or minimize (here, classification accuracy)
        ScoreFunction scoreFunction = new EvaluationScoreFunction(Evaluation.Metric.ACCURACY);

        //Termination conditions: stop after 15 minutes or 10 candidates, whichever comes first
        TerminationCondition[] terminationConditions = {
                new MaxTimeCondition(15, TimeUnit.MINUTES),
                new MaxCandidatesCondition(10)
        };

        //Combine everything into an optimization configuration
        OptimizationConfiguration configuration = new OptimizationConfiguration.Builder()
                .candidateGenerator(candidateGenerator)            //proposes candidates to evaluate
                .dataSource(dataSourceClass, dataSourceProperties) //supplies train/test data to each candidate
                .modelSaver(modelSaver)                            //how each run's results are saved
                .scoreFunction(scoreFunction)                      //metric used to rank candidates
                .terminationConditions(terminationConditions)      //when the search should stop
                .build();

        //Create the optimization runner. This one executes locally in a single JVM;
        //other execution engines (e.g. Spark or cloud machines) can be used in practice.
        //For MultiLayerNetwork candidates:
        IOptimizationRunner runner = new LocalOptimizationRunner(configuration, new MultiLayerNetworkTaskCreator());
        //For ComputationGraph candidates, use a ComputationGraphTaskCreator instead

        //Start the UI server to visualize the optimization process
        StatsStorage ss = new FileStatsStorage(new File("arbiterExampleUiStats.dl4j"));
        runner.addListeners(new ArbiterStatusListener(ss));
        UIServer.getInstance().attach(ss);

        //Run the hyperparameter optimization
        runner.execute();

        //Print out some basic stats regarding the optimization procedure
        String s = "Best score: " + runner.bestScore() + "\n" +
                "Index of model with best score: " + runner.bestScoreCandidateIndex() + "\n" +
                "Number of configurations evaluated: " + runner.numCandidatesCompleted() + "\n";
        System.out.println(s);

        //Get all results, and print out details of the best result
        int indexOfBestResult = runner.bestScoreCandidateIndex();
        List<ResultReference> allResults = runner.getResults();

        OptimizationResult bestResult = allResults.get(indexOfBestResult).getResult();
        MultiLayerNetwork bestModel = (MultiLayerNetwork) bestResult.getResultReference().getResultModel();

        System.out.println("\n\nConfiguration of best model:\n");
        System.out.println(bestModel.getLayerWiseConfigurations().toJson());

        //Wait a while before exiting, so the UI can be inspected
        Thread.sleep(60000);
        UIServer.getInstance().stop();
    }

    public static class ExampleDataSource implements DataSource {
        private int minibatchSize;

        public ExampleDataSource() {
        }

        @Override
        public void configure(Properties properties) {
            this.minibatchSize = Integer.parseInt(properties.getProperty("minibatchSize", "16"));
        }

        @Override
        public Object trainData() {
            try {
                return new MnistDataSetIterator(minibatchSize, true, 12345);
            } catch (Exception e) {
                throw new RuntimeException(e);
            }
        }

        @Override
        public Object testData() {
            try {
                return new MnistDataSetIterator(minibatchSize, false, 12345);
            } catch (Exception e) {
                throw new RuntimeException(e);
            }
        }

        @Override
        public Class<?> getDataType() {
            return DataSetIterator.class;
        }
    }
}
```
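The runner's stopping behavior can be pictured as a loop that checks every termination condition after each candidate is evaluated; as soon as any condition fires, the search ends. A plain-Java sketch of that idea (stdlib only; the names below are hypothetical and not Arbiter's internals), mirroring MaxCandidatesCondition(10) and MaxTimeCondition(15, MINUTES):

```java
import java.util.List;
import java.util.function.Predicate;

public class TerminationSketch {
    // Snapshot of the search state that each condition inspects
    record SearchState(int candidatesEvaluated, long elapsedMillis) {}

    public static void main(String[] args) {
        // Stop after 10 candidates OR 15 minutes, whichever comes first
        List<Predicate<SearchState>> conditions = List.of(
                s -> s.candidatesEvaluated() >= 10,
                s -> s.elapsedMillis() >= 15L * 60 * 1000
        );

        long start = System.currentTimeMillis();
        int evaluated = 0;
        while (true) {
            evaluated++;  // stand-in for training and scoring one candidate
            SearchState state = new SearchState(evaluated, System.currentTimeMillis() - start);
            if (conditions.stream().anyMatch(c -> c.test(state))) break;
        }
        System.out.println("Stopped after " + evaluated + " candidates");
        // → prints "Stopped after 10 candidates"
    }
}
```

Because the conditions are combined with OR, adding more conditions can only make the search stop earlier, never later.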

Republished from: https://my.oschina.net/yjwxh/blog/3007666
