/**
 * Example: trains a small feed-forward classifier on the Iris data set with DL4J on a
 * local Spark master, then prints evaluation statistics computed on the training data.
 */
public class IrisLocal {

    /** Number of feature columns in each CSV record. */
    private static final int NUM_INPUTS = 4;

    /**
     * Runs the example end to end: load the CSV, build the network, train it with
     * parameter averaging on Spark and print the evaluation statistics.
     *
     * @param args command-line arguments (not used)
     * @throws Exception if loading the data, training or evaluation fails
     */
    public static void main(String[] args) throws Exception {
        SparkConf sparkConf = new SparkConf();
        sparkConf.setMaster("local[*]");
        sparkConf.setAppName("Iris");
        JavaSparkContext sc = new JavaSparkContext(sparkConf);
        try {
            // Load the data from a path relative to the working directory into an RDD of lines.
            // CSVRecordReader converts CSV data (as a String) into a usable format for network training.
            // NOTE(review): the arguments are presumably "lines to skip" and "delimiter" - confirm.
            RecordReader recordReader = new CSVRecordReader(0, ",");
            File f = new File("src/main/resources/iris_shuffled_normalized_csv.txt");
            JavaRDD<String> irisDataLines = sc.textFile(f.getAbsolutePath());

            // labelIndex is the index of the label (target) column within each record.
            int labelIndex = 4;
            int numOutputClasses = 3;

            // Build one DataSet (feature vector + target vector) per record. The target vector is
            // derived from the label column and numOutputClasses.
            // NOTE(review): presumably a one-hot vector of length numOutputClasses - confirm
            // against RecordReaderFunction.
            JavaRDD<DataSet> trainingData =
                    irisDataLines.map(new RecordReaderFunction(recordReader, labelIndex, numOutputClasses));

            // First: create and initialize the multi-layer network. The configuration is the same as
            // in normal (non-distributed) DL4J training. The output layer width reuses
            // numOutputClasses so that it cannot drift from the label encoding above.
            MultiLayerNetwork net = new MultiLayerNetwork(buildConfiguration(NUM_INPUTS, numOutputClasses));
            net.init();

            // Second: set up the Spark training.
            // The TrainingMaster controls how learning is actually executed on Spark.
            // Here, we are using standard parameter averaging.
            int examplesPerDataSetObject = 1; // each CSV line was mapped to one DataSet above
            ParameterAveragingTrainingMaster tm =
                    new ParameterAveragingTrainingMaster.Builder(examplesPerDataSetObject)
                            .workerPrefetchNumBatches(2) // Asynchronously prefetch up to 2 batches
                            .saveUpdater(true)
                            // Averaging every 1 iteration is inefficient in practical problems
                            .averagingFrequency(1)
                            .batchSizePerWorker(8) // Number of examples that each worker gets, per fit operation
                            .build();
            SparkDl4jMultiLayer sparkNetwork = new SparkDl4jMultiLayer(sc, net, tm);

            int nEpochs = 100;
            for (int i = 0; i < nEpochs; i++) {
                sparkNetwork.fit(trainingData);
            }

            // Finally: evaluate the (training) data accuracy in a distributed manner.
            Evaluation evaluation = sparkNetwork.evaluate(trainingData);
            System.out.println(evaluation.stats());
        } finally {
            // Always release the Spark context, even when training or evaluation throws.
            sc.stop();
        }
    }

    /**
     * Builds the three-layer network configuration used by this example.
     *
     * @param numInputs  number of input features
     * @param numOutputs number of output classes
     * @return the network configuration
     */
    private static MultiLayerConfiguration buildConfiguration(int numInputs, int numOutputs) {
        int iterations = 1;
        return new NeuralNetConfiguration.Builder()
                .seed(12345)
                .iterations(iterations)
                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                .learningRate(0.5)
                .regularization(true).l2(1e-4)
                .activation("tanh")
                .weightInit(WeightInit.XAVIER)
                .list()
                .layer(0, new DenseLayer.Builder().nIn(numInputs).nOut(3).build())
                .layer(1, new DenseLayer.Builder().nIn(3).nOut(2).build())
                .layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).activation("softmax")
                        .nIn(2).nOut(numOutputs).build())
                .backprop(true).pretrain(false)
                .build();
    }
}
上面是示例程序,主要功能是:在 Spark 环境下训练神经网络,并在训练数据上对训练结果进行评估。下面从评估这一行代码开始跟踪:
Evaluation evaluation = sparkNetwork.evaluate(trainingData);
进入
SparkDl4jMultiLayer 类中的 evaluate 方法,其中传递的参数依次为 trainingData、null、64
data.mapPartitions() 方法需要一个 FlatMapFunction<Iterator<DataSet>, Evaluation> 类型的参数,这里通过 new 其实现类 EvaluateFlatMapFunction 来实例化:
public Evaluation evaluate(JavaRDD<DataSet> data, List<String> labelsList, int evalBatchSize) { Broadcast listBroadcast = labelsList == null?null:this.sc.broadcast(labelsList); JavaRDD evaluations = data.mapPartitions(new EvaluateFlatMapFunction(this.sc.broadcast(this.conf.toJson()), this.sc.broadcast(this.network.params()), evalBatchSize, listBroadcast)); return (Evaluation)evaluations.reduce(new EvaluationReduceFunction()); }
EvaluateFlatMapFunction 构造方法的第一个参数是网络配置 conf(MultiLayerConfiguration)的 JSON 字符串(经 sc.broadcast 广播),第二个参数是由 MultiLayerNetwork net = new MultiLayerNetwork(conf) 创建的
MultiLayerNetwork 对象的 params()(同样经过广播),即其中的 flattenedParams
变量,该变量是把神经网络中全部权重和偏置展平后得到的参数向量;最后两个参数分别为 64(evalBatchSize)和 null(标签列表)。EvaluateFlatMapFunction 类实现了上层接口 FlatMapFunction<Iterator<DataSet>, Evaluation> 的 call 方法,该方法负责在每个分区上用训练好的网络做评估,大致流程是:先根据广播的 JSON 配置重建网络并设置广播的参数(参数个数不一致时抛出 IllegalStateException);然后反复从迭代器中取出 DataSet,直到累计样本数达到 evalBatchSize 为止,把它们合并成一个批次,计算网络输出,并把标签与输出交给 Evaluation 进行统计;最后返回只包含这一个 Evaluation 的列表(分区为空时返回空列表)。代码如下:
public Iterable<Evaluation> call(Iterator<DataSet> dataSetIterator) throws Exception { if(!dataSetIterator.hasNext()) { return Collections.emptyList(); } else { MultiLayerNetwork network = new MultiLayerNetwork(MultiLayerConfiguration.fromJson((String)this.json.getValue())); network.init(); INDArray val = (INDArray)this.params.value(); if(val.length() != network.numParams(false)) { throw new IllegalStateException("Network did not have same number of parameters as the broadcasted set parameters"); } else { network.setParameters(val); Evaluation evaluation; if(this.labels != null) { evaluation = new Evaluation((List)this.labels.getValue()); } else { evaluation = new Evaluation(); } ArrayList collect = new ArrayList(); int totalCount = 0; while(dataSetIterator.hasNext()) { collect.clear(); int nExamples = 0; DataSet data; while(dataSetIterator.hasNext() && nExamples < this.evalBatchSize) { data = (DataSet)dataSetIterator.next(); nExamples += data.numExamples(); collect.add(data); } totalCount += nExamples; data = DataSet.merge(collect, false); INDArray out; if(data.hasMaskArrays()) { out = network.output(data.getFeatureMatrix(), false, data.getFeaturesMaskArray(), data.getLabelsMaskArray()); } else { out = network.output(data.getFeatureMatrix(), false); } if(data.getLabels().rank() == 3) { if(data.getLabelsMaskArray() == null) { evaluation.evalTimeSeries(data.getLabels(), out); } else { evaluation.evalTimeSeries(data.getLabels(), out, data.getLabelsMaskArray()); } } else { evaluation.eval(data.getLabels(), out); } } if(log.isDebugEnabled()) { log.debug("Evaluated {} examples ", Integer.valueOf(totalCount)); } return Collections.singletonList(evaluation); } } }
evaluate 方法最后的返回语句通过 reduce 把各个分区得到的 Evaluation 结果合并成一个:(Evaluation)evaluations.reduce(new EvaluationReduceFunction());
/**
 * Folds the statistics of another evaluation into this one.
 * A {@code null} argument is ignored and leaves this instance untouched.
 *
 * @param other the evaluation whose counts are added to this instance; may be {@code null}
 */
public void merge(Evaluation other) {
    if (other == null) {
        return;
    }

    // Per-class counters are summed.
    this.truePositives.incrementAll(other.truePositives);
    this.falsePositives.incrementAll(other.falsePositives);
    this.trueNegatives.incrementAll(other.trueNegatives);
    this.falseNegatives.incrementAll(other.falseNegatives);

    // Confusion matrix: copy the other one when this side has none yet, otherwise add to it.
    if (other.confusion != null) {
        if (this.confusion == null) {
            this.confusion = new ConfusionMatrix(other.confusion);
        } else {
            this.confusion.add(other.confusion);
        }
    }

    this.numRowCounter += other.numRowCounter;

    // Labels are taken from the other evaluation only when this one has none.
    if (this.labelsList.isEmpty()) {
        this.labelsList.addAll(other.labelsList);
    }
}