DL4J Source Code Reading (6): LSTM Forward-Pass Processing Flow

1.1 GravesLSTM parameter count formula

        The parameter count computed in GravesLSTMParamInitializer's public int numParams(Layer l) method is:

        int nParams = nLast * (4 * nL) //"input" weights

              + nL * (4 * nL + 3) //recurrent weights

              + 4 * nL; //bias

    To understand where this formula comes from, you need to know the concrete structure of an LSTM. The figure below shows the structure of an LSTM cell:

 [Figure: LSTM cell structure diagram]

 

        The pre-activation formulas for the input gate, forget gate, output gate, and the candidate (input-modulation) value are as follows:

$$i_t = \sigma(W_i x_t + U_i h_{t-1} + b_i)$$

$$f_t = \sigma(W_f x_t + U_f h_{t-1} + b_f)$$

$$o_t = \sigma(W_o x_t + U_o h_{t-1} + b_o)$$

$$g_t = \tanh(W_g x_t + U_g h_{t-1} + b_g)$$

        From these formulas, the input x_t is multiplied by four separate weight matrices, which gives the input part of the parameter count: nLast * (4 * nL) //"input" weights. The previous time step's output h_{t-1} likewise has four weight matrices. Because this class is GravesLSTM, the source comments note that it implements peephole connections. Its structure is shown below:

 [Figure: LSTM cell with peephole connections]
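
        With peephole connections, each sigmoid gate additionally sees the memory cell state through a per-hidden-unit weight vector. As a reference sketch (notation mine, not DL4J's), the Graves-style gate equations become:

$$i_t = \sigma(W_i x_t + U_i h_{t-1} + w_i \odot c_{t-1} + b_i)$$

$$f_t = \sigma(W_f x_t + U_f h_{t-1} + w_f \odot c_{t-1} + b_f)$$

$$o_t = \sigma(W_o x_t + U_o h_{t-1} + w_o \odot c_t + b_o)$$

        The three per-unit vectors w_i, w_f, w_o are the three extra columns per hidden unit, i.e. the +3 below; in the activateHelper() code of section 1.2 they appear as wGGTranspose, wFFTranspose, and wOOTranspose respectively.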

        So the recurrent parameter count is nL * (4 * nL + 3) //recurrent weights, where the +3 comes from the peephole connections. The plain LSTM does not implement peephole connections, and the parameter count formula in LSTMParamInitializer's public int numParams(Layer l) method is:
         int nParams = nLast * (4 * nL) //"input" weights

                + nL * (4 * nL) //recurrent weights

                + 4 * nL; //bias
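
        As a quick sanity check of the two formulas, here is a minimal sketch with made-up sizes (nLast = 3 and nL = 4 are assumptions chosen purely for illustration):

        public class LstmParamCount {
            public static void main(String[] args) {
                int nLast = 3; // hypothetical input size (nIn)
                int nL = 4;    // hypothetical hidden layer size (nOut)

                // GravesLSTM: input weights + recurrent weights (incl. 3 peephole columns) + biases
                int gravesParams = nLast * (4 * nL) + nL * (4 * nL + 3) + 4 * nL; // 48 + 76 + 16 = 140

                // LSTM (no peepholes): the same minus the 3 peephole columns per hidden unit
                int lstmParams = nLast * (4 * nL) + nL * (4 * nL) + 4 * nL;       // 48 + 64 + 16 = 128

                System.out.println("GravesLSTM params: " + gravesParams); // 140
                System.out.println("LSTM params:       " + lstmParams);   // 128
            }
        }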

 

1.2 Walkthrough of the LSTM cell data processing flow

        In the DL4J source, the static public FwdPassReturn activateHelper() method of LSTMHelpers concentrates the main data processing flow of an LSTM cell. The comments below annotate this method's source. The ifogActivations variable is especially important: most of the data for the gates is sliced out of it, so keep its origin firmly in mind. It is the backbone of the method; once you grasp it, the rest of the data flow falls into place.
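
        Before reading the code it helps to keep the whole per-time-step recurrence in view. In the notation of the formulas above (this summary is mine, not DL4J documentation), with m the mini-batch size and n_L the hidden layer size, one step computes:

$$\text{ifogActivations} = x_t W + h_{t-1} U + b \;\in\; \mathbb{R}^{m \times 4 n_L}$$

$$c_t = f_t \odot c_{t-1} + i_t \odot g_t, \qquad h_t = o_t \odot \tanh(c_t)$$

        Judging from the slicing indices in the code, the four column blocks of ifogActivations hold, in order, the candidate g, the forget gate f, the output gate o, and the input(-modulation) gate i.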

        if (input == null || input.length() == 0)

            throw new IllegalArgumentException("Invalid input: not set or 0 length");

        // Input weight matrix; rows: number of input units, columns: 4 * hidden layer size

        INDArray inputWeights = originalInputWeights;

        // Output of this layer from the previous time step

        INDArray prevOutputActivations = originalPrevOutputActivations;

        // Check whether the input is 3-D. LSTM input is normally 3-D: [miniBatchSize, inputSize, timeSeriesLength]

        boolean is2dInput = input.rank() < 3;

        // Time series length, i.e. the size of the last of the three dimensions

        int timeSeriesLength = (is2dInput ? 1 : input.size(2));

        // Number of hidden units, configured by the user

        int hiddenLayerSize = recurrentWeights.size(0);

        // Number of examples per mini-batch

        int miniBatchSize = input.size(0);

 

        INDArray prevMemCellState;

        if (originalPrevMemCellState == null) {

            // Memory cell state of this layer from the previous time step, i.e. C_{t-1} in the LSTM diagram; all zeros at initialization

            prevMemCellState = Nd4j.create(new int[] {miniBatchSize, hiddenLayerSize}, 'f');

        } else {

            prevMemCellState = originalPrevMemCellState.dup('f');

        }

 

        // The gate blocks of the recurrent weights (the IFOG part). Because some DL4J LSTMs support peepholes, recurrentWeights holds the peephole weights in addition to the gate blocks, so only the first 4 * hiddenLayerSize columns are taken here.

        INDArray recurrentWeightsIFOG = recurrentWeights

                        .get(NDArrayIndex.all(), NDArrayIndex.interval(0, 4 * hiddenLayerSize)).dup('f');

 

 

        if (conf.isUseDropConnect() && training && conf.getLayer().getDropOut() > 0) {

            inputWeights = Dropout.applyDropConnect(layer, inputWeightKey);

        }

 

        INDArray wFFTranspose = null;

        INDArray wOOTranspose = null;

        INDArray wGGTranspose = null;

        // Whether peephole connections are enabled

        if (hasPeepholeConnections) {

            // The next three are the per-unit peephole weight vectors for the three gates

            wFFTranspose = recurrentWeights

                            .get(NDArrayIndex.all(), interval(4 * hiddenLayerSize, 4 * hiddenLayerSize + 1))

                            .transpose();

            wOOTranspose = recurrentWeights

                            .get(NDArrayIndex.all(), interval(4 * hiddenLayerSize + 1, 4 * hiddenLayerSize + 2))

                            .transpose();

            wGGTranspose = recurrentWeights

                            .get(NDArrayIndex.all(), interval(4 * hiddenLayerSize + 2, 4 * hiddenLayerSize + 3))

                            .transpose();

 

            if (timeSeriesLength > 1 || forBackprop) {

                // See toMmulCompatible's own comment for the details. A matrix in this flow may be multiplied several times; if its values are laid out contiguously in memory, no copy is needed for the multiplication. This method checks that condition, returns the array unchanged if it holds, and otherwise copies it once into a compatible layout, so several multiplications cost at most one copy instead of one copy each.

                wFFTranspose = Shape.toMmulCompatible(wFFTranspose);

                wOOTranspose = Shape.toMmulCompatible(wOOTranspose);

                wGGTranspose = Shape.toMmulCompatible(wGGTranspose);

            }

        }
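
        // To summarize the layout sliced apart above: recurrentWeights has hiddenLayerSize rows and
        // 4 * hiddenLayerSize + 3 columns. Columns [0, 4*hiddenLayerSize) hold the four gate blocks
        // (candidate/input, forget, output, input-modulation), and the last three columns are the
        // per-unit peephole vectors wFF (forget), wOO (output), and wGG (input-modulation).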

 

        //Allocate arrays for activations:

        boolean sigmoidGates = gateActivationFn instanceof ActivationSigmoid;

        // Get this layer's activation function; users typically configure tanh

        IActivation afn = layer.layerConf().getActivationFn();

        INDArray outputActivations = null;

 

        FwdPassReturn toReturn = new FwdPassReturn();

        if (forBackprop) {

            toReturn.fwdPassOutputAsArrays = new INDArray[timeSeriesLength];

            toReturn.memCellState = new INDArray[timeSeriesLength];

            toReturn.memCellActivations = new INDArray[timeSeriesLength];

            toReturn.iz = new INDArray[timeSeriesLength];

            toReturn.ia = new INDArray[timeSeriesLength];

            toReturn.fa = new INDArray[timeSeriesLength];

            toReturn.oa = new INDArray[timeSeriesLength];

            toReturn.ga = new INDArray[timeSeriesLength];

            if (!sigmoidGates) {

                toReturn.fz = new INDArray[timeSeriesLength];

                toReturn.oz = new INDArray[timeSeriesLength];

                toReturn.gz = new INDArray[timeSeriesLength];

            }

 

            if (cacheMode != CacheMode.NONE) {

                try (MemoryWorkspace ws = Nd4j.getWorkspaceManager()

                                .getWorkspaceForCurrentThread(ComputationGraph.workspaceCache).notifyScopeBorrowed()) {

                    outputActivations = Nd4j.create(new int[] {miniBatchSize, hiddenLayerSize, timeSeriesLength}, 'f'); //F order to keep time steps together

                    toReturn.fwdPassOutput = outputActivations;

                }

            }

        } else {

            // Initialize the output array, i.e. h_t for every time step

            outputActivations = Nd4j.create(new int[] {miniBatchSize, hiddenLayerSize, timeSeriesLength}, 'f'); //F order to keep time steps together

            toReturn.fwdPassOutput = outputActivations;

        }

 

        Level1 l1BLAS = Nd4j.getBlasWrapper().level1();

 

        //Input validation: check input data matches nIn

        if (input.size(1) != inputWeights.size(0)) {

            throw new DL4JInvalidInputException("Received input with size(1) = " +input.size(1)

                            + " (input array shape = " + Arrays.toString(input.shape())

                            + "); input.size(1) must match layer nIn size (nIn = " +inputWeights.size(0) + ")");

        }

        //Input validation: check that if past state is provided, it has the same number of examples (mini-batch size) as the input

        //These can be different if user forgets to call rnnClearPreviousState() between calls of rnnTimeStep

        if (prevOutputActivations !=null && prevOutputActivations.size(0) != input.size(0)) {

            throw new DL4JInvalidInputException("Previous activations (stored state) number of examples = "

                            + prevOutputActivations.size(0) +" but input array number of examples = " +input.size(0)

                            + ". Possible cause: using rnnTimeStep() without calling"

                            + " rnnClearPreviousState() between different sequences?");

        }

 

        // If no previous output was provided, initialize the previous-time-step output to zeros

        if (prevOutputActivations == null) {

            prevOutputActivations = Nd4j.zeros(new int[] {miniBatchSize, hiddenLayerSize});

        }

 

        if (helper !=null) {

            FwdPassReturn ret = helper.activate(layer, conf, gateActivationFn, input, recurrentWeights, inputWeights,

                            biases, training, prevOutputActivations,prevMemCellState, forBackprop, forwards,

                            inputWeightKey,maskArray, hasPeepholeConnections);

            if (ret !=null) {

                return ret;

            }

        }

        // Loop over the time series, processing one time step per iteration

        for (int iTimeIndex = 0; iTimeIndex < timeSeriesLength; iTimeIndex++) {

            int time = iTimeIndex;

 

            if (!forwards) {

                time = timeSeriesLength - iTimeIndex - 1;

            }

 

            // Take the time-th slice of the input as this step's input: one row per example, shape [miniBatch, nIn]

            INDArray miniBatchData = (is2dInput ? input : input.tensorAlongDimension(time, 1, 0)); //Expected shape: [m,nIn]. Also deals with edge case of T=1, with 'time series' data of shape [m,nIn], equiv. to [m,nIn,1]

            miniBatchData = Shape.toMmulCompatible(miniBatchData);

 

            // if we're using cache here - let's create ifogActivations within cache workspace, so all views from this array will be valid in cache

            if (cacheMode != CacheMode.NONE)

                Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(ComputationGraph.workspaceCache)

                                .notifyScopeBorrowed();

 

            //Calculate activations for: network input + forget, output, input modulation gates. Next 3 lines are first part of those

            // Input matrix times input weight matrix. As the LSTM diagram and formulas above show, the input x_t feeds all of the gates, so its contribution to all of them is computed here in one multiplication and sliced apart further down. ifogActivations has one row per example.

            INDArray ifogActivations = miniBatchData.mmul(inputWeights); //Shape: [miniBatch,4*layerSize]

 

            if (cacheMode != CacheMode.NONE)

                Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(ComputationGraph.workspaceCache)

                                .notifyScopeLeft();

            // Previous output times the recurrent weights. Nd4j.gemm computes C = alpha*A*B + beta*C, so with alpha = beta = 1.0 the product prevOutputActivations * recurrentWeightsIFOG is accumulated into ifogActivations. As the diagram and formulas show, the previous output h_{t-1} also feeds all of the gates, so its contribution is likewise computed once here and sliced apart later.

            Nd4j.gemm(prevOutputActivations, recurrentWeightsIFOG, ifogActivations, false, false, 1.0, 1.0);

            // Then add the biases

            ifogActivations.addiRowVector(biases);

            // Slice the input (candidate) pre-activations out of ifogActivations

            INDArray inputActivations =

                            ifogActivations.get(NDArrayIndex.all(), NDArrayIndex.interval(0, hiddenLayerSize));
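
            // Column layout of ifogActivations (nL = hiddenLayerSize): [0, nL) candidate/input block,
            // [nL, 2*nL) forget gate, [2*nL, 3*nL) output gate, [3*nL, 4*nL) input-modulation gate.
            // The remaining slices below pick these blocks out one at a time.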

            if (forBackprop) {

                if (cacheMode != CacheMode.NONE)

                    Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(ComputationGraph.workspaceCache)

                                    .notifyScopeBorrowed();

 

                toReturn.iz[time] =inputActivations.dup('f');

 

                if (cacheMode != CacheMode.NONE)

                    Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(ComputationGraph.workspaceCache)

                                    .notifyScopeLeft();

            }

            // Apply this layer's activation function in place to get the candidate (input) activations. The activation function was already assigned to afn above, so afn could have been used directly here.

            layer.layerConf().getActivationFn().getActivation(inputActivations, training);

            if (forBackprop)

                toReturn.ia[time] =inputActivations;

            // Slice the forget gate pre-activations out of ifogActivations

            INDArray forgetGateActivations = ifogActivations.get(NDArrayIndex.all(),

                            NDArrayIndex.interval(hiddenLayerSize, 2 * hiddenLayerSize));

            if (hasPeepholeConnections) {

                // Peephole contribution to the forget gate: the previous cell state scaled element-wise by wFF

                INDArray pmcellWFF = prevMemCellState.dup('f').muliRowVector(wFFTranspose);

                // Add the peephole contribution to the forget gate pre-activations

                l1BLAS.axpy(pmcellWFF.length(), 1.0, pmcellWFF, forgetGateActivations); //y = a*x + y i.e., forgetGateActivations.addi(pmcellWFF)

            }

            //Above line: treats matrix as a vector. Can only do this because we're sure both pmcellWFF and forgetGateActivations are f order, offset 0 and have same strides

            if (forBackprop && !sigmoidGates) {

                if (cacheMode != CacheMode.NONE)

                    Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(ComputationGraph.workspaceCache)

                                    .notifyScopeBorrowed();

 

                toReturn.fz[time] =forgetGateActivations.dup('f');//Forget gate pre-out (z)

 

                if (cacheMode != CacheMode.NONE)

                    Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(ComputationGraph.workspaceCache)

                                    .notifyScopeLeft();

            }

            // Apply the gate activation function (sigmoid by default) to the forget gate. At this point the forget gate is complete, and the computation so far matches the LSTM structure diagram given above closely.

            // The input(-modulation) gate and output gate below follow essentially the same pattern as the forget gate; only the differences are annotated.

            gateActivationFn.getActivation(forgetGateActivations, training);

 

            if (forBackprop)

                toReturn.fa[time] =forgetGateActivations;


            // Slice the input-modulation (input) gate pre-activations out of ifogActivations

            INDArray inputModGateActivations = ifogActivations.get(NDArrayIndex.all(),

                            NDArrayIndex.interval(3 * hiddenLayerSize, 4 * hiddenLayerSize));

            if (hasPeepholeConnections) {

                // Peephole contribution to the input-modulation gate, also from the previous cell state

                INDArray pmcellWGG = prevMemCellState.dup('f').muliRowVector(wGGTranspose);

                l1BLAS.axpy(pmcellWGG.length(), 1.0, pmcellWGG, inputModGateActivations); //inputModGateActivations.addi(pmcellWGG)

            }

            if (forBackprop && !sigmoidGates) {

                if (cacheMode != CacheMode.NONE)

                    Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(ComputationGraph.workspaceCache)

                                    .notifyScopeBorrowed();

 

                toReturn.gz[time] =inputModGateActivations.dup('f');//Input modulation gate pre-out (z)

 

                if (cacheMode != CacheMode.NONE)

                    Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(ComputationGraph.workspaceCache)

                                    .notifyScopeLeft();

            }

            gateActivationFn.getActivation(inputModGateActivations,training);

            if (forBackprop)

                toReturn.ga[time] =inputModGateActivations;

 

            //Memory cell state

            INDArray currentMemoryCellState;

            INDArray inputModMulInput;

            if (forBackprop) {

                if (cacheMode != CacheMode.NONE)

                    Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(ComputationGraph.workspaceCache)

                                    .notifyScopeBorrowed();

 

                currentMemoryCellState =prevMemCellState.dup('f').muli(forgetGateActivations);

 

                if (cacheMode != CacheMode.NONE)

                    Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(ComputationGraph.workspaceCache)

                                    .notifyScopeLeft();

 

                // this variable isn't stored in cache

                inputModMulInput = inputModGateActivations.dup('f').muli(inputActivations);

            } else {

                // Forget gate times the previous cell state. Despite the name currentMemoryCellState, at this point it is still an intermediate value, not yet the current state.

                currentMemoryCellState = forgetGateActivations.muli(prevMemCellState);

                // Final input contribution. Comparing with the LSTM diagram and formulas above, a tanh appears to be missing here; in fact inputActivations already went through the layer activation in layer.layerConf().getActivationFn().getActivation(inputActivations, training) above, and that activation is usually tanh.

                inputModMulInput = inputModGateActivations.muli(inputActivations);

            }

            // Here we obtain the actual current cell state

            l1BLAS.axpy(currentMemoryCellState.length(), 1.0, inputModMulInput, currentMemoryCellState); //currentMemoryCellState.addi(inputModMulInput)
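
            // In formula terms this is c_t = f_t * c_{t-1} + i_t * g_t (element-wise): the forget gate scales
            // the old state, the input(-modulation) gate scales the tanh-activated candidate, and the two are summed.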

            // Notice that the gate blocks are not laid out in forget/input/output order but (after the candidate block) in forget, output, input(-modulation) order. That is why the output gate is sliced here with 2 * hiddenLayerSize .. 3 * hiddenLayerSize while the input-modulation slice above used 3 * hiddenLayerSize .. 4 * hiddenLayerSize. The data itself does not care which block serves which gate; slicing them in order would simply make the code easier to read.

            INDArray outputGateActivations = ifogActivations.get(NDArrayIndex.all(),

                            NDArrayIndex.interval(2 * hiddenLayerSize, 3 * hiddenLayerSize));

            if (hasPeepholeConnections) {

                // Peephole contribution to the output gate; note it uses the *current* cell state rather than the previous one

                INDArray pmcellWOO = currentMemoryCellState.dup('f').muliRowVector(wOOTranspose);

                l1BLAS.axpy(pmcellWOO.length(), 1.0, pmcellWOO, outputGateActivations); //outputGateActivations.addi(pmcellWOO)

            }

            if (forBackprop && !sigmoidGates) {

                if (cacheMode != CacheMode.NONE)

                    Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(ComputationGraph.workspaceCache)

                                    .notifyScopeBorrowed();

 

                toReturn.oz[time] =outputGateActivations.dup('f');//Output gate activations

 

                if (cacheMode != CacheMode.NONE)

                    Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(ComputationGraph.workspaceCache)

                                    .notifyScopeLeft();

            }

            gateActivationFn.getActivation(outputGateActivations,training);

            if (forBackprop)

                toReturn.oa[time] =outputGateActivations;

 

 

            ////////////// same as with iFogActivations - if we use cache, let's create this array right there

            if (cacheMode != CacheMode.NONE)

                Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(ComputationGraph.workspaceCache)

                                .notifyScopeBorrowed();

 

            //LSTM unit outputs:

            // Apply the layer activation function (usually tanh) to the current cell state

            INDArray currMemoryCellActivation = afn.getActivation(currentMemoryCellState.dup('f'), training);

 

 

            if (cacheMode != CacheMode.NONE)

                Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(ComputationGraph.workspaceCache)

                                .notifyScopeLeft();

            ///////////////////

 

            INDArray currHiddenUnitActivations;

            if (forBackprop) {

                if (cacheMode != CacheMode.NONE)

                    Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(ComputationGraph.workspaceCache)

                                    .notifyScopeBorrowed();

                // Compute this cell's output h_t, again matching the formula given above.

                currHiddenUnitActivations = currMemoryCellActivation.dup('f').muli(outputGateActivations); //Expected shape: [m,hiddenLayerSize]

 

                if (cacheMode != CacheMode.NONE)

                    Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(ComputationGraph.workspaceCache)

                                    .notifyScopeLeft();

            } else {

                currHiddenUnitActivations =currMemoryCellActivation.muli(outputGateActivations);//Expected shape: [m,hiddenLayerSize]

            }

 

            if (maskArray !=null) {

                //Mask array is present: bidirectional RNN -> need to zero out these activations to avoid

                // incorrectly using activations from masked time steps (i.e., want 0 initialization in both directions)

                //We *also* need to apply this to the memory cells, as they are carried forward

                //Mask array has shape [minibatch, timeSeriesLength] -> get column

                INDArray timeStepMaskColumn =maskArray.getColumn(time);

                currHiddenUnitActivations.muliColumnVector(timeStepMaskColumn);

                currentMemoryCellState.muliColumnVector(timeStepMaskColumn);

            }

 

            if (forBackprop) {

                toReturn.fwdPassOutputAsArrays[time] =currHiddenUnitActivations;

                toReturn.memCellState[time] =currentMemoryCellState;

                toReturn.memCellActivations[time] =currMemoryCellActivation;

 

                if (cacheMode != CacheMode.NONE) {

                    outputActivations.tensorAlongDimension(time, 1, 0).assign(currHiddenUnitActivations);

                }

            } else {

                outputActivations.tensorAlongDimension(time, 1, 0).assign(currHiddenUnitActivations);

            }

            // Carry the current output forward as the previous output for the next iteration

            prevOutputActivations = currHiddenUnitActivations;

            // Carry the current cell state forward as the previous cell state for the next iteration

            prevMemCellState = currentMemoryCellState;

 

            // no need to dup here, if that's cache - it's already within Cache workspace

            toReturn.lastAct =currHiddenUnitActivations;

 

            // the same as above, already in cache

            toReturn.lastMemCell =currentMemoryCellState;

        }

 

 

 

        //toReturn.leverageTo(ComputationGraph.workspaceExternal);

 

        toReturn.prevAct =originalPrevOutputActivations;

        toReturn.prevMemCell =originalPrevMemCellState;

 

        return toReturn;
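
        To see the code path above in action, here is a minimal sketch of how a forward pass reaches activateHelper() from user code. It assumes a DL4J version contemporary with this source; the layer sizes, batch size, and sequence length are made up for illustration.

        import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
        import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
        import org.deeplearning4j.nn.conf.layers.GravesLSTM;
        import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
        import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
        import org.nd4j.linalg.activations.Activation;
        import org.nd4j.linalg.api.ndarray.INDArray;
        import org.nd4j.linalg.factory.Nd4j;
        import org.nd4j.linalg.lossfunctions.LossFunctions;

        public class LstmForwardDemo {
            public static void main(String[] args) {
                int nIn = 3, nHidden = 4, nOut = 2; // hypothetical sizes

                // A tiny network: one GravesLSTM layer (peepholes) followed by an RNN output layer
                MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                        .list()
                        .layer(0, new GravesLSTM.Builder().nIn(nIn).nOut(nHidden)
                                .activation(Activation.TANH).build())
                        .layer(1, new RnnOutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                                .activation(Activation.SOFTMAX).nIn(nHidden).nOut(nOut).build())
                        .build();
                MultiLayerNetwork net = new MultiLayerNetwork(conf);
                net.init();

                // Input shape [miniBatchSize, nIn, timeSeriesLength], matching the 3-D check in activateHelper()
                INDArray input = Nd4j.rand(new int[] {5, nIn, 7});

                // Forward pass; the LSTM layer runs the activateHelper() flow walked through above
                INDArray out = net.output(input);
                System.out.println(java.util.Arrays.toString(out.shape())); // [5, nOut, 7]
            }
        }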