DL4J源码阅读(六):LSTM信号前传处理流程

本文详细介绍了DL4J中LSTM的GravesLSTMParamInitializer参数个数公式,并解析了LSTM神经元数据处理流程。重点探讨了输入权重、循环权重和偏置的计算,以及LSTM激活辅助方法activateHelper()中的核心步骤,包括数据验证、矩阵运算和门控计算等。
摘要由CSDN通过智能技术生成

1.1 GravesLSTM参数个数公式

        GravesLSTMParamInitializerpublic int numParams(Layer l)方法的参数个数计算公式如下:

        int nParams = nLast * (4 * nL) //"input" weights

              + nL * (4 * nL + 3) //recurrent weights

              + 4 * nL; //bias

    想要搞清楚这个公式的由来,就要知道LSTM的具体结构。下图是LSTM的结构图:

 

 

        遗忘门、输入门、输出门对应的公式如下:

 




        从上述中公式中可以看出,输入xt对应四种权重,则计算参数个数公式中输入参数个数是nLast * (4 * nL) //"input" weights。前一时间步的输出ht-1对应四种权重。由于这个类是GravesLSTM,看源码注释中标明是实现了peephole connections。其结构如图:

 

        所以循环参数个数是nL * (4 * nL + 3) //recurrent weights。这里的+3就是多出来的peephole connections。而LSTM是没有实现peephole connectionsLSTMParamInitializerpublic int numParams(Layer l)方法的参数个数计算公式如下:        
         int nParams = nLast * (4 * nL) //"input" weights

                + nL * (4 * nL) //recurrent weights

                + 4 * nL; //bias

 

1.2 LSTM神经元数据处理流程解读

        DL4J源码中LSTMHelpersstatic public FwdPassReturn activateHelper()方法集中了LSTM神经元数据处理的主要流程。下面中文注释是对这个方法源码的解读。其中ifogActivations变量比较重要,三个门的数据大多来源于它,重点记住这个变量数据的来源。这个变量是这个方法的纲领。抓住这个纲领,数据处理流程自然就清晰可见了。

        if (input ==null || input.length() == 0)

            throw new IllegalArgumentException("Invalid input: not set or 0 length");

        // 输入数据权重矩阵,行数:输入神经元数,列数:输出神经元个数*4

        INDArray inputWeights = originalInputWeights;

        // 本神经元前一时间步的输出

        INDArray prevOutputActivations =originalPrevOutputActivations;

        // 判断输入是否为3维。LSTM的输入结构一般都是3维:[miniBatchSize,inputSize,timeSeriesLength]

        boolean is2dInput =input.rank() < 3;

        // 获取时间序列长度,就是3维中最后一维的长度

        int timeSeriesLength = (is2dInput ? 1 :input.size(2));

        // 隐藏层神经元个数,用户自定义

        int hiddenLayerSize =recurrentWeights.size(0);

        // 小批量每批数据量

        int miniBatchSize =input.size(0);

 

        INDArray prevMemCellState;

        if (originalPrevMemCellState ==null) {

            // 本神经元前一时间步记忆状态,就是LSTM结构图中的Ct-1,初始化时是0

            prevMemCellState = Nd4j.create(new int[] { miniBatchSize,hiddenLayerSize}, 'f');

        } else {

            prevMemCellState = originalPrevMemCellState.dup('f');

        }

 

        // 循环权重中输入、遗忘、输出三个门的权重。因为DL4J中有的LSTM支持窥视孔,所以recurrentWeights中除了三个门的权重,还有窥视孔的权重。

        INDArray recurrentWeightsIFOG =recurrentWeights

                        .get(NDArrayIndex.all(), NDArrayIndex.interval(0, 4 *hiddenLayerSize)).dup('f');

 

 

        if (conf.isUseDropConnect() &&training && conf.getLayer().getDropOut() > 0) {

            inputWeights = Dropout.applyDropConnect(layer,inputWeightKey);

        }

 

        INDArray wFFTranspose = null;

        INDArray wOOTranspose = null;

        INDArray wGGTranspose = null;

        // 是否支持窥视孔

        if (hasPeepholeConnections) {

            // 下面三个是窥视孔的三个门权重

            wFFTranspose = recurrentWeights

                            .get(NDArrayIndex.all(), interval(4 *hiddenLayerSize, 4 * hiddenLayerSize + 1))

                            .transpose();

            wOOTranspose = recurrentWeights

                            .get(NDArrayIndex.all(), interval(4 *hiddenLayerSize + 1, 4 * hiddenLayerSize + 2))

                            .transpose();

            wGGTranspose = recurrentWeights

                            .get(NDArrayIndex.all(), interval(4 *hidd

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值