Dl4j-fit(DataSetIterator iterator)源码阅读(一)

本文详细解读了DL4J库中fit(DataSetIterator iterator)方法的源码,从网络模型开始,探讨了iterator.asyncSupported()的异步支持检查,trainingListeners的训练监听器,以及layerWiseConfigurations.isBackprop()的反向传播设置。通过TaskUtils.buildTask(iter)创建训练任务,并深入Solver的配置与实例化过程。
摘要由CSDN通过智能技术生成

fit(DataSetIterator iterator)源码阅读

1 网络模型

//Create the network
int numInput = 1;
int numOutputs = 1;
int nHidden = 2;
MultiLayerNetwork net = new MultiLayerNetwork(new NeuralNetConfiguration.Builder()
        .seed(seed)
        .iterations(iterations)
        .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
        .learningRate(learningRate)
        .weightInit(WeightInit.XAVIER)
        .updater(Updater.SGD)     //To configure: .updater(new Nesterovs(0.9))
        .list()
        .layer(0, new DenseLayer.Builder().nIn(numInput).nOut(nHidden)
                .activation(Activation.RELU).dropOut(0.5)
                .build())
        .layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.MSE)
                .activation(Activation.IDENTITY)
                .nIn(numInput).nOut(numOutputs).build())
        .pretrain(false).backprop(true).build()
);

调用net.fit(iterator);对源码进行单步阅读。

2 fit(DataSetIterator iterator)

@Override
public void fit(DataSetIterator iterator) {
    DataSetIterator iter;
    // we're wrapping all iterators into AsyncDataSetIterator to provide background prefetch - where appropriate
    if (iterator.asyncSupported()) {
        iter = new AsyncDataSetIterator(iterator, 2);
    } else {
        iter = iterator;
    }

    if (trainingListeners.size() > 0) {
        for (TrainingListener tl : trainingListeners) {
            tl.onEpochStart(this);
        }
    }

    if (layerWiseConfigurations.isPretrain()) {
        pretrain(iter);
        if (iter.resetSupported()) {
            iter.reset();
        }
    }
    if (layerWiseConfigurations.isBackprop()) {
        update(TaskUtils.buildTask(iter));
        if (!iter.hasNext() && iter.resetSupported()) {
            iter.reset();
        }
        while (iter.hasNext()) {
            DataSet next = iter.next();
            if (next.getFeatureMatrix() == null || next.getLabels() == null)
                break;

            boolean hasMaskArrays = next.hasMaskArrays();

            if (layerWiseConfigurations.getBackpropType() == BackpropType.TruncatedBPTT) {
                doTruncatedBPTT(next.getFeatureMatrix(), next.getLabels(), next.getFeaturesMaskArray(),
                                next.getLabelsMaskArray());
            } else {
                if (hasMaskArrays)
                    setLayerMaskArrays(next.getFeaturesMaskArray(), next.getLabelsMaskArray());
                setInput(next.getFeatureMatrix());
                setLabels(next.getLabels());
                if (solver == null) {
                    solver = new Solver.Builder().configure(conf()).listeners(getListeners()).model(
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值