机器学习方法篇(7)------LSTM公式推导

● 每周一言

养成习惯很难,丢掉习惯也很难。

导语

上篇对循环神经网络RNN进行了公式推导,并提到RNN的梯度消失问题,容易将较早之前的有用信息“忘却”。而本篇要讲的LSTM模型,通过增加“门”来有效弥补朴素RNN的缺陷。那么,LSTM模型具体是如何实现的?

LSTM模型

LSTM(Long Short-Term Memory)长短期记忆网络,是一种升级版RNN模型,模型结构如下图所示:

机器学习方法篇(7)------LSTM公式推导

图中方框我们称为记忆单元,其中实线箭头代表当前时刻的信息传递,虚线箭头表示上一时刻的信息传递。从结构图中我们看出,LSTM模型共增加了三个门:输入门、遗忘门和输出门。进入block的箭头代表输入,而出去的箭头代表输出。LSTM的前向传播共公式如下:
机器学习方法篇(7)------LSTM公式推导

图中所有带h的权重矩阵均代表一种泛指,为LSTM的各种变种做准备,表示任意一条从上一时刻指向当前时刻的边,本文暂不考虑。与上篇公式类似,a代表汇集计算结果,b代表**计算结果,Wil代表输入数据与输入门之间的权重矩阵,Wcl代表上一时刻Cell状态与输入门之间的权重矩阵,WiΦ代表输入数据与遗忘门之间的权重矩阵,WcΦ代表上一时刻Cell状态与遗忘门之间的权重矩阵,Wiω代表输入数据与输出门之间的权重矩阵,Wcω代表Cell状态与输出门之间的权重矩阵,Wic代表输入层原有的权重矩阵。需要注意的是,图中Cell一栏描述的是从下方输入到中间Cell输出的整个传播过程。

和朴素RNN的推导一样,有了前向传播公式,我们就能逐个写出LSTM网络中各个参数矩阵的梯度计算公式。首先,由于输出门不牵扯时间维度,我们可以直接写出输出门WiωWcω的迭代公式,如下图:

机器学习方法篇(7)------LSTM公式推导

遗忘门的权重矩阵WiΦ也可以直接给出,如下图:
机器学习方法篇(7)------LSTM公式推导

而对于遗忘门的权重矩阵WcΦ,由于是和上一时刻Cell状态做汇集计算,残差除了来自当前Cell,还来自下一时刻的Cell,因此需要写出下一时刻Cell传播至本时刻遗忘门的时间维度前向传播公式,如下图:
机器学习方法篇(7)------LSTM公式推导

有了上面的公式,我们就能完整写出WcΦ的梯度公式了。如下图所示(如果对这个时间维度前向公式不理解,可以参考上一篇我对朴素RNN的公式推导过程):
机器学习方法篇(7)------LSTM公式推导

请注意,上图中L”和前面的L’不一样,这里只是为了式子简洁。

推完遗忘门公式,就可以此类推输入门与Cell的公式。其中输入门基本与遗忘门的推法一样,残差都是来自本时刻和下一时刻Cell。而Cell的残差则来自三个地方:输出层、输出门和下一时刻Cell。其中输出层和输出门残差可直接写出;而下一时刻Cell的残差,我们只要写出对应的时间维度前向传播公式便可写出。由于时间关系,这里就不详细推导遗忘门和Cell的梯度公式了,各位若有兴趣可自行继续推导。

机器学习方法篇(7)------LSTM公式推导

相比于朴素RNN模型,LSTM模型更为复杂,且可调整和变化的地方也更多。比如:增加peephole将Cell状态连接到每个门,变体模型Gated Recurrent Unit (GRU),以及后面出现的Attention模型等。LSTM模型在语音识别、图像识别、手写识别、以及预测疾病、点击率和股票等众多领域中都发挥着惊人的效果,是目前最火的神经网络模型之一。敬请期待下节。

对本文推导有任何问题和疑问,欢迎留言交流。

结语

感谢各位的耐心阅读,后续文章于每周日奉上,敬请期待。欢迎大家关注小斗公众号 对半独白

机器学习方法篇(7)------LSTM公式推导