DL4J Source Code Reading (6): The LSTM Forward-Pass Processing Flow
1.1 The GravesLSTM parameter-count formula
In GravesLSTMParamInitializer, the public int numParams(Layer l) method computes the number of parameters as follows:
int nParams = nLast * (4 * nL) //"input" weights
+ nL * (4 * nL + 3) //recurrent weights
+ 4 * nL; //bias
To understand where this formula comes from, you need to know the concrete structure of an LSTM cell. [Figure: LSTM structure diagram]
The formulas for the forget, input, and output gates are as follows:
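Written out in standard notation, an LSTM cell without peephole connections computes:

i_t = \sigma(W_i x_t + R_i h_{t-1} + b_i)
f_t = \sigma(W_f x_t + R_f h_{t-1} + b_f)
o_t = \sigma(W_o x_t + R_o h_{t-1} + b_o)
g_t = \tanh(W_g x_t + R_g h_{t-1} + b_g)
c_t = f_t \odot c_{t-1} + i_t \odot g_t
h_t = o_t \odot \tanh(c_t)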
From the formulas above you can see that the input x_t is multiplied by four weight matrices (W_i, W_f, W_o, W_g), which is exactly the input term nLast * (4 * nL) //"input" weights in the parameter count. Likewise, the previous time step's output h_{t-1} is multiplied by four recurrent weight matrices. Since this class is GravesLSTM, the source comments state that it implements peephole connections, whose structure is as follows:
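With peephole connections, in the variant used by Graves, the gates additionally see the cell state:

i_t = \sigma(W_i x_t + R_i h_{t-1} + p_i \odot c_{t-1} + b_i)
f_t = \sigma(W_f x_t + R_f h_{t-1} + p_f \odot c_{t-1} + b_f)
o_t = \sigma(W_o x_t + R_o h_{t-1} + p_o \odot c_t + b_o)

Each peephole weight p_i, p_f, p_o is a vector of length nL, and DL4J stores them as three extra columns of the recurrent weight matrix.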
So the recurrent parameter count is nL * (4 * nL + 3) //recurrent weights, where the +3 accounts for the three peephole weight columns. The plain LSTM layer does not implement peephole connections, so in LSTMParamInitializer the public int numParams(Layer l) method computes the parameter count as follows:
int nParams = nLast * (4 * nL) //"input" weights
+ nL * (4 * nL) //recurrent weights
+ 4 * nL; //bias
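To make the two formulas concrete, here is a small sketch with made-up sizes (nLast = 10 inputs and nL = 20 units are assumptions for illustration, not values taken from the source):

int nLast = 10;  // nIn: number of inputs feeding this layer (hypothetical)
int nL = 20;     // layer size: number of LSTM units (hypothetical)
int gravesLstmParams = nLast * (4 * nL) + nL * (4 * nL + 3) + 4 * nL; // 800 + 1660 + 80 = 2540
int lstmParams       = nLast * (4 * nL) + nL * (4 * nL)     + 4 * nL; // 800 + 1600 + 80 = 2480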
1.2 Walking through the LSTM cell's data processing flow
In the DL4J source, the static public FwdPassReturn activateHelper() method of LSTMHelpers contains the main data-processing flow of an LSTM cell. The comments below annotate this method's source code. The ifogActivations variable is the most important one: the data for the gates is mostly sliced out of it, so pay close attention to where its values come from. This variable is the backbone of the method; once you grasp it, the rest of the processing flow becomes clear.
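Before diving in, here is a standalone sketch (with made-up sizes, so the shapes are easy to see) of how ifogActivations is formed and then sliced into its four blocks. This is essentially the pattern the method follows; the real code uses Nd4j.gemm to accumulate the recurrent term, which is mathematically the same.

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.NDArrayIndex;

int m = 2, nIn = 3, nL = 4;                      // hypothetical sizes
INDArray x     = Nd4j.rand(m, nIn);              // input for one time step, [m, nIn]
INDArray hPrev = Nd4j.rand(m, nL);               // previous output h_{t-1}, [m, nL]
INDArray W     = Nd4j.rand(nIn, 4 * nL);         // input weights, [nIn, 4*nL]
INDArray R     = Nd4j.rand(nL, 4 * nL);          // recurrent IFOG weights, [nL, 4*nL]
INDArray b     = Nd4j.rand(1, 4 * nL);           // biases, [1, 4*nL]

// One big matrix holding the pre-activations of all four blocks: [m, 4*nL]
INDArray ifog = x.mmul(W).addi(hPrev.mmul(R)).addiRowVector(b);

// Slices, in the column order activateHelper() uses:
INDArray cellIn = ifog.get(NDArrayIndex.all(), NDArrayIndex.interval(0,      nL));     // "input" block
INDArray forget = ifog.get(NDArrayIndex.all(), NDArrayIndex.interval(nL,     2 * nL)); // forget gate
INDArray output = ifog.get(NDArrayIndex.all(), NDArrayIndex.interval(2 * nL, 3 * nL)); // output gate
INDArray inMod  = ifog.get(NDArrayIndex.all(), NDArrayIndex.interval(3 * nL, 4 * nL)); // input modulation gate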
if (input == null || input.length() == 0)
throw new IllegalArgumentException("Invalid input: not set or 0 length");
// Input weight matrix; rows: number of input units (nIn), columns: 4 * number of hidden units
INDArray inputWeights = originalInputWeights;
// This layer's output from the previous time step (h_{t-1})
INDArray prevOutputActivations = originalPrevOutputActivations;
// Check whether the input is 3-D. LSTM input is normally 3-D: [miniBatchSize, inputSize, timeSeriesLength]
boolean is2dInput = input.rank() < 3;
// Time series length, i.e. the size of the last dimension of the 3-D input
int timeSeriesLength = (is2dInput ? 1 : input.size(2));
// Number of hidden units (layer size), configured by the user
int hiddenLayerSize = recurrentWeights.size(0);
// Minibatch size
int miniBatchSize = input.size(0);
INDArray prevMemCellState;
if (originalPrevMemCellState == null) {
// The memory cell state from the previous time step, i.e. c_{t-1} in the LSTM diagram; initialized to zeros
prevMemCellState = Nd4j.create(new int[] {miniBatchSize, hiddenLayerSize}, 'f');
} else {
prevMemCellState = originalPrevMemCellState.dup('f');
}
// The recurrent weights for the four IFOG blocks (input, forget and output gates plus the input-modulation block). Because some DL4J LSTMs support peepholes, recurrentWeights also holds the peephole weights in addition to these.
INDArray recurrentWeightsIFOG = recurrentWeights
.get(NDArrayIndex.all(), NDArrayIndex.interval(0, 4 * hiddenLayerSize)).dup('f');
if (conf.isUseDropConnect() && training && conf.getLayer().getDropOut() > 0) {
inputWeights = Dropout.applyDropConnect(layer, inputWeightKey);
}
INDArray wFFTranspose = null;
INDArray wOOTranspose = null;
INDArray wGGTranspose = null;
// Does this layer have peephole connections?
if (hasPeepholeConnections) {
// The three peephole weight vectors: wFF for the forget gate, wOO for the output gate, wGG for the input-modulation gate
wFFTranspose = recurrentWeights
.get(NDArrayIndex.all(), interval(4 * hiddenLayerSize, 4 * hiddenLayerSize + 1))
.transpose();
wOOTranspose = recurrentWeights
.get(NDArrayIndex.all(), interval(4 * hiddenLayerSize + 1, 4 * hiddenLayerSize + 2))
.transpose();
wGGTranspose = recurrentWeights
.get(NDArrayIndex.all(), interval(4 * hiddenLayerSize + 2, 4 * hiddenLayerSize + 3))
.transpose();
if (timeSeriesLength > 1 || forBackprop) {
// See the Javadoc of toMmulCompatible for what it does. A matrix may be used in several multiplications in this flow. If its values are laid out contiguously in memory (in a BLAS-friendly order), mmul does not need to copy it. This method checks that condition: if the matrix already satisfies it, it is returned as-is; if not, it is copied once so that it does. That way, no matter how many multiplications follow, at most one copy is made instead of one per multiplication.
wFFTranspose = Shape.toMmulCompatible(wFFTranspose);
wOOTranspose = Shape.toMmulCompatible(wOOTranspose);
wGGTranspose = Shape.toMmulCompatible(wGGTranspose);
}
}
//Allocate arrays for activations:
boolean sigmoidGates = gateActivationFn instanceof ActivationSigmoid;
// Get this layer's activation function; usually configured as tanh
IActivation afn = layer.layerConf().getActivationFn();
INDArray outputActivations = null;
FwdPassReturn toReturn = new FwdPassReturn();
if (forBackprop) {
toReturn.fwdPassOutputAsArrays = new INDArray[timeSeriesLength];
toReturn.memCellState = new INDArray[timeSeriesLength];
toReturn.memCellActivations = new INDArray[timeSeriesLength];
toReturn.iz = new INDArray[timeSeriesLength];
toReturn.ia = new INDArray[timeSeriesLength];
toReturn.fa = new INDArray[timeSeriesLength];
toReturn.oa = new INDArray[timeSeriesLength];
toReturn.ga = new INDArray[timeSeriesLength];
if (!sigmoidGates) {
toReturn.fz = new INDArray[timeSeriesLength];
toReturn.oz = new INDArray[timeSeriesLength];
toReturn.gz = new INDArray[timeSeriesLength];
}
if (cacheMode != CacheMode.NONE) {
try (MemoryWorkspace ws = Nd4j.getWorkspaceManager()
.getWorkspaceForCurrentThread(ComputationGraph.workspaceCache).notifyScopeBorrowed()) {
outputActivations = Nd4j.create(new int[] {miniBatchSize, hiddenLayerSize, timeSeriesLength}, 'f'); //F order to keep time steps together
toReturn.fwdPassOutput = outputActivations;
}
}
} else {
// Allocate the output array, i.e. h_t for every time step
outputActivations = Nd4j.create(new int[] {miniBatchSize, hiddenLayerSize, timeSeriesLength}, 'f'); //F order to keep time steps together
toReturn.fwdPassOutput = outputActivations;
}
Level1 l1BLAS = Nd4j.getBlasWrapper().level1();
//Input validation: check input data matches nIn
if (input.size(1) != inputWeights.size(0)) {
throw new DL4JInvalidInputException("Received input with size(1) = " + input.size(1)
+ " (input array shape = " + Arrays.toString(input.shape())
+ "); input.size(1) must match layer nIn size (nIn = " + inputWeights.size(0) + ")");
}
//Input validation: check that if a past state is provided, it has the same minibatch size as the input
//These can be different if user forgets to call rnnClearPreviousState() between calls of rnnTimeStep
if (prevOutputActivations != null && prevOutputActivations.size(0) != input.size(0)) {
throw new DL4JInvalidInputException("Previous activations (stored state) number of examples = "
+ prevOutputActivations.size(0) + " but input array number of examples = " + input.size(0)
+ ". Possible cause: using rnnTimeStep() without calling"
+ " rnnClearPreviousState() between different sequences?");
}
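For context, this is the usage pattern that check guards against (a hedged sketch: net is assumed to be a MultiLayerNetwork containing this recurrent layer, and step1/step2/otherSequenceStep are placeholder input arrays):

INDArray out1 = net.rnnTimeStep(step1);   // stores h_t and c_t internally
INDArray out2 = net.rnnTimeStep(step2);   // reuses the stored state from the previous call
net.rnnClearPreviousState();              // clear the stored state before an unrelated sequence,
                                          // otherwise the size check above may trip
INDArray out3 = net.rnnTimeStep(otherSequenceStep);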
// Initialize the previous-time-step output h_{t-1} to zeros if none was provided
if (prevOutputActivations == null) {
prevOutputActivations = Nd4j.zeros(new int[] {miniBatchSize, hiddenLayerSize});
}
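// If a platform-specific helper (e.g. cuDNN) is available and handles this configuration, it performs the whole forward pass and its result is returned directly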
if (helper != null) {
FwdPassReturn ret = helper.activate(layer, conf, gateActivationFn, input, recurrentWeights, inputWeights,
biases, training, prevOutputActivations, prevMemCellState, forBackprop, forwards,
inputWeightKey, maskArray, hasPeepholeConnections);
if (ret != null) {
return ret;
}
}
// Loop over the time series
for (int iTimeIndex = 0; iTimeIndex < timeSeriesLength; iTimeIndex++) {
int time = iTimeIndex;
if (!forwards) {
time = timeSeriesLength - iTimeIndex - 1;
}
// Take the slice of the input for time step `time`: a 2-D matrix of shape [miniBatchSize, nIn]
INDArray miniBatchData = (is2dInput ? input : input.tensorAlongDimension(time, 1, 0)); //[Expected shape: [m,nIn]. Also deals with edge case of T=1, with 'time series' data of shape [m,nIn], equiv. to [m,nIn,1]
miniBatchData = Shape.toMmulCompatible(miniBatchData);
// if we're using cache here - let's create ifogActivations within cache workspace, so all views from this array will be valid in cache
if (cacheMode != CacheMode.NONE)
Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(ComputationGraph.workspaceCache)
.notifyScopeBorrowed();
//Calculate activations for: network input + forget, output, input modulation gates. Next 3 lines are first part of those
// Input matrix times input weight matrix. As the LSTM structure and formulas above show, the input x_t feeds all of the gates and the cell input, so all of those contributions are computed here in a single multiplication and sliced apart later. ifogActivations has shape [miniBatch, 4*layerSize].
INDArray ifogActivations = miniBatchData.mmul(inputWeights); //Shape: [miniBatch,4*layerSize]
if (cacheMode != CacheMode.NONE)
Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(ComputationGraph.workspaceCache)
.notifyScopeLeft();
// Previous output times recurrent weights. Since the last argument (beta) is 1.0, the product is accumulated into ifogActivations. As the structure and formulas above show, the previous output h_{t-1} also feeds all of the gates, so its contribution is likewise computed in a single gemm and sliced apart later.
Nd4j.gemm(prevOutputActivations, recurrentWeightsIFOG, ifogActivations, false, false, 1.0, 1.0);
// Then add the biases
ifogActivations.addiRowVector(biases);
// Slice the cell-input ("input") pre-activations out of ifogActivations
INDArray inputActivations =
ifogActivations.get(NDArrayIndex.all(), NDArrayIndex.interval(0, hiddenLayerSize));
if (forBackprop) {
if (cacheMode != CacheMode.NONE)
Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(ComputationGraph.workspaceCache)
.notifyScopeBorrowed();
toReturn.iz[time] = inputActivations.dup('f');
if (cacheMode != CacheMode.NONE)
Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(ComputationGraph.workspaceCache)
.notifyScopeLeft();
}
// Apply this layer's activation function in place. The activation function was already assigned to afn above, so afn could simply have been used here.
layer.layerConf().getActivationFn().getActivation(inputActivations, training);
if (forBackprop)
toReturn.ia[time] = inputActivations;
// Slice the forget-gate pre-activations out of ifogActivations
INDArray forgetGateActivations = ifogActivations.get(NDArrayIndex.all(),
NDArrayIndex.interval(hiddenLayerSize, 2 * hiddenLayerSize));
if (hasPeepholeConnections) {
// Forget-gate peephole: previous cell state times wFF
INDArray pmcellWFF = prevMemCellState.dup('f').muliRowVector(wFFTranspose);
// Add the peephole term to the forget-gate pre-activations
l1BLAS.axpy(pmcellWFF.length(), 1.0, pmcellWFF, forgetGateActivations); //y = a*x + y i.e., forgetGateActivations.addi(pmcellWFF)
}
//Above line: treats matrix as a vector. Can only do this because we're sure both pwcelWFF and forgetGateACtivations are f order, offset 0 and have same strides
if (forBackprop && !sigmoidGates) {
if (cacheMode != CacheMode.NONE)
Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(ComputationGraph.workspaceCache)
.notifyScopeBorrowed();
toReturn.fz[time] = forgetGateActivations.dup('f'); //Forget gate pre-out (z)
if (cacheMode != CacheMode.NONE)
Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(ComputationGraph.workspaceCache)
.notifyScopeLeft();
}
// Apply the gate activation function to the forget gate; the gate activation is sigmoid by default. At this point the forget gate is done, and the computation so far matches the LSTM structure and formulas given above closely.
The input and output gates below follow nearly the same flow as the forget gate, so I only annotate the places that differ.
gateActivationFn.getActivation(forgetGateActivations, training);
if (forBackprop)
toReturn.fa[time] = forgetGateActivations;
INDArray inputModGateActivations = ifogActivations.get(NDArrayIndex.all(),
NDArrayIndex.interval(3 * hiddenLayerSize, 4 * hiddenLayerSize));
if (hasPeepholeConnections) {
INDArray pmcellWGG = prevMemCellState.dup('f').muliRowVector(wGGTranspose);
l1BLAS.axpy(pmcellWGG.length(), 1.0, pmcellWGG, inputModGateActivations); //inputModGateActivations.addi(pmcellWGG)
}
if (forBackprop && !sigmoidGates) {
if (cacheMode != CacheMode.NONE)
Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(ComputationGraph.workspaceCache)
.notifyScopeBorrowed();
toReturn.gz[time] = inputModGateActivations.dup('f'); //Input modulation gate pre-out (z)
if (cacheMode != CacheMode.NONE)
Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(ComputationGraph.workspaceCache)
.notifyScopeLeft();
}
gateActivationFn.getActivation(inputModGateActivations, training);
if (forBackprop)
toReturn.ga[time] = inputModGateActivations;
//Memory cell state
INDArray currentMemoryCellState;
INDArray inputModMulInput;
if (forBackprop) {
if (cacheMode != CacheMode.NONE)
Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(ComputationGraph.workspaceCache)
.notifyScopeBorrowed();
currentMemoryCellState = prevMemCellState.dup('f').muli(forgetGateActivations);
if (cacheMode != CacheMode.NONE)
Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(ComputationGraph.workspaceCache)
.notifyScopeLeft();
// this variable isn't stored in cache
inputModMulInput = inputModGateActivations.dup('f').muli(inputActivations);
} else {
// Forget gate times previous cell state. Despite being named currentMemoryCellState, at this point it is only an intermediate value, not yet the current state.
currentMemoryCellState = forgetGateActivations.muli(prevMemCellState);
// The input-gate contribution. Compared with the LSTM structure and formulas above, a tanh seems to be missing here; in fact inputActivations already went through the layer activation function earlier, in layer.layerConf().getActivationFn().getActivation(inputActivations, training), and that activation is usually configured as tanh.
inputModMulInput = inputModGateActivations.muli(inputActivations);
}
// Here the real current cell state is obtained
l1BLAS.axpy(currentMemoryCellState.length(), 1.0, inputModMulInput, currentMemoryCellState); //currentMemoryCellState.addi(inputModMulInput)
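In formula terms, these two branches compute the standard cell update c_t = f_t \odot c_{t-1} + i_t \odot g_t, where in DL4J's naming forgetGateActivations is f_t, prevMemCellState is c_{t-1}, inputModGateActivations plays the role of i_t, and inputActivations (after the layer activation function) plays the role of g_t.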
// Careful readers may have noticed: after the cell-input block, the gate blocks are not laid out in forget/input/output order but in forget, output, input-modulation order. That is why the output gate is sliced at [2 * hiddenLayerSize, 3 * hiddenLayerSize), while the input-modulation gate above was sliced at [3 * hiddenLayerSize, 4 * hiddenLayerSize). The data itself does not care which block serves which gate, as long as the assignment is consistent; a strictly sequential layout would just make the code a little easier to read.
INDArray outputGateActivations = ifogActivations.get(NDArrayIndex.all(),
NDArrayIndex.interval(2 * hiddenLayerSize, 3 * hiddenLayerSize));
if (hasPeepholeConnections) {
INDArray pmcellWOO = currentMemoryCellState.dup('f').muliRowVector(wOOTranspose);
l1BLAS.axpy(pmcellWOO.length(), 1.0, pmcellWOO, outputGateActivations); //outputGateActivations.addi(pmcellWOO)
}
if (forBackprop && !sigmoidGates) {
if (cacheMode != CacheMode.NONE)
Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(ComputationGraph.workspaceCache)
.notifyScopeBorrowed();
toReturn.oz[time] = outputGateActivations.dup('f'); //Output gate activations
if (cacheMode != CacheMode.NONE)
Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(ComputationGraph.workspaceCache)
.notifyScopeLeft();
}
gateActivationFn.getActivation(outputGateActivations, training);
if (forBackprop)
toReturn.oa[time] = outputGateActivations;
////////////// same as with iFogActivations - if we use cache, let's create this array right there
if (cacheMode != CacheMode.NONE)
Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(ComputationGraph.workspaceCache)
.notifyScopeBorrowed();
//LSTM unit outputs:
INDArray currMemoryCellActivation = afn.getActivation(currentMemoryCellState.dup('f'), training);
if (cacheMode != CacheMode.NONE)
Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(ComputationGraph.workspaceCache)
.notifyScopeLeft();
///////////////////
INDArray currHiddenUnitActivations;
if (forBackprop) {
if (cacheMode != CacheMode.NONE)
Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(ComputationGraph.workspaceCache)
.notifyScopeBorrowed();
// Compute the current output h_t of the cell, again matching the formula given above.
currHiddenUnitActivations = currMemoryCellActivation.dup('f').muli(outputGateActivations); //Expected shape: [m,hiddenLayerSize]
if (cacheMode != CacheMode.NONE)
Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(ComputationGraph.workspaceCache)
.notifyScopeLeft();
} else {
currHiddenUnitActivations = currMemoryCellActivation.muli(outputGateActivations); //Expected shape: [m,hiddenLayerSize]
}
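Explicitly, this is h_t = o_t \odot \tanh(c_t): currMemoryCellActivation is the layer activation afn applied to the new cell state (tanh, assuming the usual configuration), and outputGateActivations is o_t.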
if (maskArray != null) {
//Mask array is present: bidirectional RNN -> need to zero out these activations to avoid
// incorrectly using activations from masked time steps (i.e., want 0 initialization in both directions)
//We *also* need to apply this to the memory cells, as they are carried forward
//Mask array has shape [minibatch, timeSeriesLength] -> get column
INDArray timeStepMaskColumn = maskArray.getColumn(time);
currHiddenUnitActivations.muliColumnVector(timeStepMaskColumn);
currentMemoryCellState.muliColumnVector(timeStepMaskColumn);
}
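As a tiny illustration of what this masking does (the values are made up):

INDArray act  = Nd4j.ones(3, 4);                              // [miniBatch = 3, hiddenLayerSize = 4]
INDArray mask = Nd4j.create(new double[][] {{1}, {1}, {0}});  // mask column for this time step, [3, 1]
act.muliColumnVector(mask);                                   // rows 0 and 1 unchanged, row 2 zeroed out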
if (forBackprop) {
toReturn.fwdPassOutputAsArrays[time] = currHiddenUnitActivations;
toReturn.memCellState[time] = currentMemoryCellState;
toReturn.memCellActivations[time] = currMemoryCellActivation;
if (cacheMode != CacheMode.NONE) {
outputActivations.tensorAlongDimension(time, 1, 0).assign(currHiddenUnitActivations);
}
} else {
outputActivations.tensorAlongDimension(time, 1, 0).assign(currHiddenUnitActivations);
}
// Set the previous output to the current output, ready for the next loop iteration
prevOutputActivations = currHiddenUnitActivations;
// Set the previous cell state to the current state, ready for the next loop iteration
prevMemCellState = currentMemoryCellState;
// no need to dup here, if that's cache - it's already within Cache workspace
toReturn.lastAct = currHiddenUnitActivations;
// the same as above, already in cache
toReturn.lastMemCell = currentMemoryCellState;
}
//toReturn.leverageTo(ComputationGraph.workspaceExternal);
toReturn.prevAct = originalPrevOutputActivations;
toReturn.prevMemCell = originalPrevMemCellState;
return toReturn;
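To summarize the per-time-step flow without the workspace and caching machinery, here is a minimal, self-contained sketch of a single step of an LSTM with peepholes, written in the standard gate order rather than DL4J's internal column order. All sizes and the random weights are made up; this is an illustration of the math, not the library's implementation.

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.ops.transforms.Transforms;

public class LstmStepSketch {
    public static void main(String[] args) {
        int m = 2, nIn = 3, nL = 4;                 // hypothetical minibatch, input and layer sizes
        INDArray x     = Nd4j.rand(m, nIn);         // input at time t
        INDArray hPrev = Nd4j.zeros(m, nL);         // h_{t-1}
        INDArray cPrev = Nd4j.zeros(m, nL);         // c_{t-1}
        // One input weight matrix, recurrent weight matrix and bias per block, plus three peephole vectors
        INDArray Wi = Nd4j.rand(nIn, nL), Wf = Nd4j.rand(nIn, nL), Wo = Nd4j.rand(nIn, nL), Wg = Nd4j.rand(nIn, nL);
        INDArray Ri = Nd4j.rand(nL, nL),  Rf = Nd4j.rand(nL, nL),  Ro = Nd4j.rand(nL, nL),  Rg = Nd4j.rand(nL, nL);
        INDArray bi = Nd4j.zeros(1, nL),  bf = Nd4j.zeros(1, nL),  bo = Nd4j.zeros(1, nL),  bg = Nd4j.zeros(1, nL);
        INDArray pi = Nd4j.rand(1, nL),   pf = Nd4j.rand(1, nL),   po = Nd4j.rand(1, nL);

        // Candidate cell input (DL4J's "input activations"), layer activation = tanh
        INDArray g = Transforms.tanh(x.mmul(Wg).addi(hPrev.mmul(Rg)).addiRowVector(bg));
        // Gates: sigmoid of input + recurrent + peephole + bias terms
        INDArray i = Transforms.sigmoid(x.mmul(Wi).addi(hPrev.mmul(Ri)).addi(cPrev.mulRowVector(pi)).addiRowVector(bi));
        INDArray f = Transforms.sigmoid(x.mmul(Wf).addi(hPrev.mmul(Rf)).addi(cPrev.mulRowVector(pf)).addiRowVector(bf));
        // New cell state and output; the output gate peeks at the *new* cell state
        INDArray c = f.mul(cPrev).addi(i.mul(g));
        INDArray o = Transforms.sigmoid(x.mmul(Wo).addi(hPrev.mmul(Ro)).addi(c.mulRowVector(po)).addiRowVector(bo));
        INDArray h = o.mul(Transforms.tanh(c));
        System.out.println("h_t shape: " + java.util.Arrays.toString(h.shape()));
    }
}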