fit(DataSetIterator iterator)源码阅读
1 网络模型
// Create the network: a single dense hidden layer (ReLU, dropOut 0.5) followed by
// an MSE output layer with identity activation.
// NOTE: seed, iterations and learningRate are not declared in this snippet;
// they are assumed to be defined by the surrounding code.
int numInput = 1;
int numOutputs = 1;
int nHidden = 2;
MultiLayerNetwork net = new MultiLayerNetwork(new NeuralNetConfiguration.Builder()
        .seed(seed)
        .iterations(iterations)
        .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
        .learningRate(learningRate)
        .weightInit(WeightInit.XAVIER)
        .updater(Updater.SGD) //To configure: .updater(new Nesterovs(0.9))
        .list()
        // Layer 0: numInput -> nHidden
        .layer(0, new DenseLayer.Builder().nIn(numInput).nOut(nHidden)
                .activation(Activation.RELU).dropOut(0.5)
                .build())
        // Layer 1: nHidden -> numOutputs. Its nIn must match layer 0's nOut (nHidden),
        // not the network's input size.
        .layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.MSE)
                .activation(Activation.IDENTITY)
                .nIn(nHidden).nOut(numOutputs).build())
        // Pretraining disabled, backprop enabled.
        .pretrain(false).backprop(true).build()
);
使用上面构建的网络调用 net.fit(iterator);
下面对该方法的源码进行单步跟踪阅读。
2 fit(DataSetIterator iterator)
@Override
public void fit(DataSetIterator iterator) {
DataSetIterator iter;
// we're wrapping all iterators into AsyncDataSetIterator to provide background prefetch - where appropriate
if (iterator.asyncSupported()) {
iter = new AsyncDataSetIterator(iterator, 2);
} else {
iter = iterator;
}
if (trainingListeners.size() > 0) {
for (TrainingListener tl : trainingListeners) {
tl.onEpochStart(this);
}
}
if (layerWiseConfigurations.isPretrain()) {
pretrain(iter);
if (iter.resetSupported()) {
iter.reset();
}
}
if (layerWiseConfigurations.isBackprop()) {
update(TaskUtils.buildTask(iter));
if (!iter.hasNext() && iter.resetSupported()) {
iter.reset();
}
while (iter.hasNext()) {
DataSet next = iter.next();
if (next.getFeatureMatrix() == null || next.getLabels() == null)
break;
boolean hasMaskArrays = next.hasMaskArrays();
if (layerWiseConfigurations.getBackpropType() == BackpropType.TruncatedBPTT) {
doTruncatedBPTT(next.getFeatureMatrix(), next.getLabels(), next.getFeaturesMaskArray(),
next.getLabelsMaskArray());
} else {
if (hasMaskArrays)
setLayerMaskArrays(next.getFeaturesMaskArray(), next.getLabelsMaskArray());
setInput(next.getFeatureMatrix());
setLabels(next.getLabels());
if (solver == null) {
solver = new Solver.Builder().configure(conf()).listeners(getListeners()).model(