DL4J源码阅读(七):LSTM梯度计算

本文深入探讨了DL4J库中LSTM的backpropGradientHelper方法,详细解释了梯度计算的过程,包括隐藏层和前一层神经元的数量、批处理大小、窥视孔权重等关键概念,以及如何处理不同时间步长的误差反向传播。通过理解这些步骤,可以更好地理解和优化LSTM网络的训练过程。
摘要由CSDN通过智能技术生成

    LSTMHelpers类中的backpropGradientHelper方法是梯度计算过程。

 

        // 本层神经元个数

        int hiddenLayerSize = recurrentWeights.size(0); //i.e., n^L        // 前一层神经元个数

        int prevLayerSize = inputWeights.size(0); //n^(L-1)

        // 一批数据量

        int miniBatchSize = epsilon.size(0);

        boolean is2dInput = epsilon.rank() < 3; //Edge case: T=1 may have shape [miniBatchSize,n^(L+1)], equiv. to [miniBatchSize,n^(L+1),1]

        int timeSeriesLength = (is2dInput ? 1 : epsilon.size(2));

 

        INDArray wFFTranspose = null;

        INDArray wOOTranspose = null;

        INDArray wGGTranspose = null;

        // 窥视孔

        if (hasPeepholeConnections) {

            // 下面三个是三个窥视孔对应的权重

            wFFTranspose = recurrentWeights.get(NDArrayIndex.all(), point(4 * hiddenLayerSize)).transpose();

            wOOTranspose = recurrentWeights.get(NDArrayIndex.all(), point(4 * hiddenLayerSize + 1)).transpose();

            wGGTranspose = recurrentWeights.get(NDArrayIndex.all(), point(4 * hiddenLayerSize + 2)).transpose();

        }

 

        // 从循环权重矩阵中分离出三个门的权重。从代码效率上看,这条语句应该和上面的窥视孔判断结合起来。如果没有窥视孔,可以直接让wIFOG = recurrentWeights。有窥视孔才这样操作。

        INDArray wIFOG = recurrentWeights.get(NDArrayIndex.all(), NDArrayIndex.interval(0, 4 * hiddenLayerSize));

        //F order here so that content for time steps are together

        // 这个epsilonNext和本层输入矩阵的结构是一致的。

        INDArray epsilonNext = Nd4j.create(new int[] { miniBatchSize, prevLayerSize, timeSeriesLength}, 'f'); //i.e., what would be W^L*(delta^L)^T. Shape: [m,n^(L-1),T]

 

        INDArray nablaCellStateNext = null;

        // 这里初始化了一个全为0的矩阵,后面四个矩阵从这个矩阵中截取。当后面四个矩阵变化时,deltaifogNext 也会相应变化。

        INDArray deltaifogNext = Nd4j.create(new int[] { miniBatchSize, 4 * hiddenLayerSize}, 'f');

        INDArray deltaiNext = deltaifogNext.get(NDArrayIndex.all(), NDArrayIndex.interval(0, hiddenLayerSize));

        INDArray deltafNext = deltaifogNext.get(NDArrayIndex.all(),

                        NDArrayIndex.interval(hiddenLayerSize, 2 * hiddenLayerSize));

        INDArray deltaoNext = deltaifogNext.get(NDArrayIndex.all(),

                        NDArrayIndex.interval(2 * hiddenLayerSize, 3 * hiddenLayerSize));

        INDArray deltagNext = deltaifogNext.get(NDArrayIndex.all(),

                        NDArrayIndex.interval(3 * hiddenLayerSize, 4 * hiddenLayerSize));

 

        Level1 l1BLAS = Nd4j.getBlasWrapper().level1();

        int endIdx = 0;

 

        if (truncatedBPTT) {

            endIdx = Math.max(0, timeSeriesLength - tbpttBackwardLength);

        }

 

        // 从梯度视图中获取输入、循环、偏移三种权重,并都置为0

        INDArray iwGradientsOut = gradientViews.get(inputWeightKey);

        INDArray rwGradientsOut = gradientViews.get(recurrentWeightKey); //Order: {I,F,O,G,FF,OO,GG}

        INDArray bGradientsOut = gradientViews.get(biasWeightKey);

        iwGradientsOut.assign(0);

        rwGradientsOut.assign(0);

        bGradientsOut.assign(0);

 

        // 都是0,和上边说过的一样,应该和下面窥视孔的判断结合起来。

        INDArray rwGradientsIFOG =

                        rwGradientsOut.get(NDArrayIndex.all(), NDArrayIndex.interval(0, 4 * hiddenLayerSize));

        INDArray rwGradientsFF = null;

        INDArray rwGradientsOO = null;

        INDArray rwGradientsGG = null;

        if (hasPeepholeConnections) {

            rwGradientsFF = rwGradientsOut.get(NDArrayIndex.all(), NDArrayIndex.point(4 * hiddenLayerSize));

            rwGradientsOO = rwGradientsOut.get(NDArrayIndex.all(), NDArrayIndex.point(4 * hiddenLayerSize + 1));

            rwGradientsGG = rwGradientsOut.get(NDArrayIndex.all(), NDArrayIndex.point(4 * hiddenLayerSize + 2));

        }

 

        if (helper !=

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值