LSTMHelpers 类中的 backpropGradientHelper 方法实现了 LSTM 层反向传播时的梯度计算过程。
// 本层神经元个数
int hiddenLayerSize = recurrentWeights.size(0); //i.e., n^L
// 前一层神经元个数
int prevLayerSize = inputWeights.size(0); //n^(L-1)
// 一批数据量
int miniBatchSize = epsilon.size(0);
boolean is2dInput = epsilon.rank() < 3; //Edge case: T=1 may have shape [miniBatchSize,n^(L+1)], equiv. to [miniBatchSize,n^(L+1),1]
int timeSeriesLength = (is2dInput ? 1 : epsilon.size(2));
INDArray wFFTranspose = null;
INDArray wOOTranspose = null;
INDArray wGGTranspose = null;
// 窥视孔
if (hasPeepholeConnections) {
// 下面三个是三个窥视孔对应的权重
wFFTranspose = recurrentWeights.get(NDArrayIndex.all(), point(4 * hiddenLayerSize)).transpose();
wOOTranspose = recurrentWeights.get(NDArrayIndex.all(), point(4 * hiddenLayerSize + 1)).transpose();
wGGTranspose = recurrentWeights.get(NDArrayIndex.all(), point(4 * hiddenLayerSize + 2)).transpose();
}
// 从循环权重矩阵中取出前 4 * hiddenLayerSize 列,即 I、F、O、G 四部分的循环权重(不包含后面的窥视孔权重列)。从代码效率上看,这条语句应该和上面的窥视孔判断结合起来:如果没有窥视孔,可以直接让 wIFOG = recurrentWeights;有窥视孔时才需要这样截取。
INDArray wIFOG = recurrentWeights.get(NDArrayIndex.all(), NDArrayIndex.interval(0, 4 * hiddenLayerSize));
//F order here so that content for time steps are together
// 这个epsilonNext和本层输入矩阵的结构是一致的。
INDArray epsilonNext = Nd4j.create(new int[] { miniBatchSize, prevLayerSize, timeSeriesLength}, 'f'); //i.e., what would be W^L*(delta^L)^T. Shape: [m,n^(L-1),T]
INDArray nablaCellStateNext = null;
// 这里初始化了一个全为0的矩阵,后面四个矩阵从这个矩阵中截取。当后面四个矩阵变化时,deltaifogNext 也会相应变化。
INDArray deltaifogNext = Nd4j.create(new int[] { miniBatchSize, 4 * hiddenLayerSize}, 'f');
INDArray deltaiNext = deltaifogNext.get(NDArrayIndex.all(), NDArrayIndex.interval(0, hiddenLayerSize));
INDArray deltafNext = deltaifogNext.get(NDArrayIndex.all(),
NDArrayIndex.interval(hiddenLayerSize, 2 * hiddenLayerSize));
INDArray deltaoNext = deltaifogNext.get(NDArrayIndex.all(),
NDArrayIndex.interval(2 * hiddenLayerSize, 3 * hiddenLayerSize));
INDArray deltagNext = deltaifogNext.get(NDArrayIndex.all(),
NDArrayIndex.interval(3 * hiddenLayerSize, 4 * hiddenLayerSize));
Level1 l1BLAS = Nd4j.getBlasWrapper().level1();
int endIdx = 0;
if (truncatedBPTT) {
endIdx = Math.max(0, timeSeriesLength - tbpttBackwardLength);
}
// 从梯度视图中取出输入权重、循环权重、偏置三者对应的梯度数组,并全部置为0
INDArray iwGradientsOut = gradientViews.get(inputWeightKey);
INDArray rwGradientsOut = gradientViews.get(recurrentWeightKey); //Order: {I,F,O,G,FF,OO,GG}
INDArray bGradientsOut = gradientViews.get(biasWeightKey);
iwGradientsOut.assign(0);
rwGradientsOut.assign(0);
bGradientsOut.assign(0);
// rwGradientsOut 刚刚已被全部置为0,所以这里截取出的视图也都是0。和前面 wIFOG 处说过的一样,这个截取操作应该和下面的窥视孔判断结合起来。
INDArray rwGradientsIFOG =
rwGradientsOut.get(NDArrayIndex.all(), NDArrayIndex.interval(0, 4 * hiddenLayerSize));
INDArray rwGradientsFF = null;
INDArray rwGradientsOO = null;
INDArray rwGradientsGG = null;
if (hasPeepholeConnections) {
rwGradientsFF = rwGradientsOut.get(NDArrayIndex.all(), NDArrayIndex.point(4 * hiddenLayerSize));
rwGradientsOO = rwGradientsOut.get(NDArrayIndex.all(), NDArrayIndex.point(4 * hiddenLayerSize + 1));
rwGradientsGG = rwGradientsOut.get(NDArrayIndex.all(), NDArrayIndex.point(4 * hiddenLayerSize + 2));
}
if (helper !=