Dl4j-fit(DataSetIterator iterator)源码阅读(八) 根据参数更新梯度

前面经过反向传播,已经计算出了模型的损失函数得分以及梯度,在反向传播完成之后会返回到package org.deeplearning4j.optimize.solvers包下的BaseOptimizer.gradientAndScore()方法体重继续执行,该方法体中继续执行。

反向传播计算完的参数还需要经过梯度正则化以及L1,L2参数惩罚

@Override
public Pair<Gradient, Double> gradientAndScore() {
    oldScore = score;
    //包含反向传播,已经算出了模型的损失函数得分以及梯度
    model.computeGradientAndScore();

    if (iterationListeners != null && iterationListeners.size() > 0) {
        for (IterationListener l : iterationListeners) {
            if (l instanceof TrainingListener) {
                ((TrainingListener) l).onGradientCalculation(model);
            }
        }
    }

    //获取模型中的梯度和损失函数得分
    Pair<Gradient, Double> pair = model.gradientAndScore();
    //将模型的损失函数得分赋值为优化器的成员变量中
    score = pair.getSecond();
    //然后根据参数更新梯度
    updateGradientAccordingToParams(pair.getFirst(), model, model.batchSize());
    return pair;
}

1 updateGradientAccordingToParams

@Override
public void updateGradientAccordingToParams(Gradient gradient, Model model, int batchSize) {
    //首先判断是ComputationGraph还是MultiLayerNetwork
    if (model instanceof ComputationGraph) {
        ComputationGraph graph = (ComputationGraph) model;
        if (computationGraphUpdater == null) {
            computationGraphUpdater = new ComputationGraphUpdater(graph);
        }
        computationGraphUpdater.update(graph, gradient, getIterationCount(model), batchSize);
    } else {

        //获取更新器
        if (updater == null)
            updater = UpdaterCreator.getUpdater(model);

        //将model改为Layer类型,这个时候需要注意,在多层网络架构的时候
        //MultiLayerNetwork 可以认为是输出层
        //MultiLayerNetwork is a neural network with multiple layers in a stack, and usually an output layer.
        Layer layer = (Layer) model;
        updater.update(layer, gradient, getIterationCount(model), batchSize);
    }
}

1.1 UpdaterCreator.getUpdater(model)

首先需要根据模型设置来获取模型参数的更新器

public class UpdaterCreator {

    private UpdaterCreator() {}

    public static org.deeplearning4j.nn.api.Updater getUpdater(Model layer) {
        //判断网络架构
        if (layer instanceof MultiLayerNetwork) {
            return new MultiLayerUpdater((MultiLayerNetwork) layer);
        } else {
            return new LayerUpdater();
        }
    }

}

之后构造一个新的更新器类MultiLayerUpdater
package org.deeplearning4j.nn.updater;包下,所调用的更新器的构造函数为:

/**
 * MultiLayerUpdater: Gradient updater for MultiLayerNetworks.
 * Expects backprop gradients for all layers to be in single Gradient object,
 * keyed by "0_b", "1_w" etc., as per MultiLayerNetwork.backward()
 */
public MultiLayerUpdater(MultiLayerNetwork network) {
    //获取架构的网络层
    Layer[] layers = network.getLayers();
    //逐层判断是否为空
    for (int i = 0; i < layers.length; i++) {
        //守护条件,保证获取到的layer全都不为null
        while (layers[i] == null)
            layers = network.getLayers();
    }
    //根据网络层个数构造网络层更新器
    layerUpdaters = new Updater[layers.length];
    //更新器状态个数
    int updaterStateSize = 0;
    for (int i = 0; i < layers.length; i++) {
        Layer layer = layers[i];
        //这里依旧判断当前层是否为空,如果为空则会跑出空指针有慈航
        Preconditions.checkNotNull(layer);
        //根据当前网络层构建层更新器
        layerUpdaters[i] = UpdaterCreator.getUpdater(layer);

        //这里的更新器因为使用的是SGD,所以StateSize这里不管传入什么值,返回的均为0
        updaterStateSize += layerUpdaters[i].stateSizeForLayer(layer);
    }
    //初始化更新器状态
    //Initialize the updater state:
    if (updaterStateSize > 0) {
        //May be 0 if all SGD updaters, for example
        viewArray = Nd4j.createUninitialized(new int[] {
  1, updaterStateSize}, Nd4j.order());
    }

    //需要跨越多远获取子视图
    int soFar = 0;
    for (int i = 0; i < layers.length; i++) {
        //获取更新器状态
        int thisSize = layerUpdaters[i].stateSizeForLayer(layers[i]);

        //如果为0
        if (thisSize == 0)
            continue;

        //如果不为0,则获取子视图
        INDArray view = viewArray.get(NDArrayIndex.point(0), NDArrayIndex.interval(soFar, soFar + thisSize));

        //设置到对应的更新器中
        layerUpdaters[i].setStateViewArray(layers[i], view, true);
        soFar += thisSize;
    }
}

到这里MultiLayerUpdater执行完成,继续返回上层函数

1 updateGradientAccordingToParams

里面继续执行以下语句

updater.update(layer, gradient, getIterationCount(model), batchSize);

1.2 updater.update()

@Override
public void update(Layer layer, Gradient gradient, int iteration, int batchSize) {
    MultiLayerNetwork mln = (MultiLayerNetwork) layer;

    //根据LayerUpdaters的个数构建 层梯度 的个数
    Gradient[] layerGradients = new Gradient[layerUpdaters.length];
    //实例化层梯度
    for (int i = 0; i < layerGradients.length; i++)
        layerGradients[i] = new DefaultGradient();

    //然后遍历已经计算好的梯度
    for (Map.Entry<String, INDArray> gradien
  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值