1.1 GravesLSTM参数个数公式
GravesLSTMParamInitializer中 public int numParams(Layer l)方法的参数个数计算公式如下:
int nParams = nLast * (4 * nL) //"input" weights
+ nL * (4 * nL + 3) //recurrent weights
+ 4 * nL; //bias
想要搞清楚这个公式的由来,就要知道LSTM的具体结构。下图是LSTM的结构图:
遗忘门、输入门、输出门对应的公式如下:
从上述中公式中可以看出,输入xt对应四种权重,则计算参数个数公式中输入参数个数是nLast * (4 * nL) //"input" weights。前一时间步的输出ht-1对应四种权重。由于这个类是GravesLSTM,看源码注释中标明是实现了peephole connections。其结构如图:
所以循环参数个数是nL * (4 * nL + 3) //recurrent weights。这里的+3就是多出来的peephole connections。而LSTM是没有实现peephole connections,LSTMParamInitializer中public int numParams(Layer l)方法的参数个数计算公式如下:
int nParams = nLast * (4 * nL) //"input" weights
+ nL * (4 * nL) //recurrent weights
+ 4 * nL; //bias
1.2 LSTM神经元数据处理流程解读
DL4J源码中LSTMHelpers的static public FwdPassReturn activateHelper()方法集中了LSTM神经元数据处理的主要流程。下面中文注释是对这个方法源码的解读。其中ifogActivations变量比较重要,三个门的数据大多来源于它,重点记住这个变量数据的来源。这个变量是这个方法的纲领。抓住这个纲领,数据处理流程自然就清晰可见了。
if (input ==null || input.length() == 0)
throw new IllegalArgumentException("Invalid input: not set or 0 length");
// 输入数据权重矩阵,行数:输入神经元数,列数:输出神经元个数*4
INDArray inputWeights = originalInputWeights;
// 本神经元前一时间步的输出
INDArray prevOutputActivations =originalPrevOutputActivations;
// 判断输入是否为3维。LSTM的输入结构一般都是3维:[miniBatchSize,inputSize,timeSeriesLength]
boolean is2dInput =input.rank() < 3;
// 获取时间序列长度,就是3维中最后一维的长度
int timeSeriesLength = (is2dInput ? 1 :input.size(2));
// 隐藏层神经元个数,用户自定义
int hiddenLayerSize =recurrentWeights.size(0);
// 小批量每批数据量
int miniBatchSize =input.size(0);
INDArray prevMemCellState;
if (originalPrevMemCellState ==null) {
// 本神经元前一时间步记忆状态,就是LSTM结构图中的Ct-1,初始化时是0
prevMemCellState = Nd4j.create(new int[] { miniBatchSize,hiddenLayerSize}, 'f');
} else {
prevMemCellState = originalPrevMemCellState.dup('f');
}
// 循环权重中输入、遗忘、输出三个门的权重。因为DL4J中有的LSTM支持窥视孔,所以recurrentWeights中除了三个门的权重,还有窥视孔的权重。
INDArray recurrentWeightsIFOG =recurrentWeights
.get(NDArrayIndex.all(), NDArrayIndex.interval(0, 4 *hiddenLayerSize)).dup('f');
if (conf.isUseDropConnect() &&training && conf.getLayer().getDropOut() > 0) {
inputWeights = Dropout.applyDropConnect(layer,inputWeightKey);
}
INDArray wFFTranspose = null;
INDArray wOOTranspose = null;
INDArray wGGTranspose = null;
// 是否支持窥视孔
if (hasPeepholeConnections) {
// 下面三个是窥视孔的三个门权重
wFFTranspose = recurrentWeights
.get(NDArrayIndex.all(), interval(4 *hiddenLayerSize, 4 * hiddenLayerSize + 1))
.transpose();
wOOTranspose = recurrentWeights
.get(NDArrayIndex.all(), interval(4 *hiddenLayerSize + 1, 4 * hiddenLayerSize + 2))
.transpose();
wGGTranspose = recurrentWeights
.get(NDArrayIndex.all(), interval(4 *hidd