深入理解RNN与LSTM

循环神经网络(Recurrent Neural Network)基础

在深度学习领域,神经网络已经被用于处理各类数据,如CNN在图像领域的应用,全连接神经网络在分类问题的应用等。随着神经网络在各个领域的渗透,传统以统计机器学习为主的NLP问题,也逐渐开始采用深度学习的方法来解决。如由Google Brain提出的Word2Vec模型,便将传统BoW等统计方法的词向量方法,带入到了以深度学习为基础的Distribution Representation的方法中来,真正地将NLP问题带入了深度学习的练兵场。当然,RNN的模型并非局限于NLP领域,而是为了解决一系列序列化数据的建模问题,如视频、语音等,而文本也只是序列化数据的一种典型案例。

RNN的特征在于,对于每个RNN神经元,其参数始终共享,即对于文本序列,任何一个输入都经过相同的处理,得到一个输出。在传统的全连接神经网络的结构中,神经元之间互不影响,并没有直接联系,神经元与神经元之间相互独立。而在RNN结构中,隐藏层的神经元开始通过一个隐藏状态所相连,通常会被表示为hth_t。在理解RNN与全连接神经网络时,需要对两者的结构加以区分,通常,FCN会采用水平方式进行可视化理解,即每一层的神经元垂直排列,而不同层之间以水平方式排布。但在RNN的模型图中,隐藏层的不同神经元之间通常水平排列,而隐藏层的不同层之间以垂直方式排列,如图所示,在FCN网络中,各层水平布局,隐藏层各神经元相互独立,在RNN中,各层以垂直布局,而水平方向上布局着各神经元。注意:RNN结构图只是为了使得结构直观易理解,而在水平方向上其实每个A都相同,对于每个时间步其都是采用同一个神经元进行前向传播。

深入理解RNN与LSTM

RNN的前向传播

在RNN中,序列数据按照其时间顺序,依次输入到网络中,而时间顺序则表示时间步的概念。在RNN中,隐藏状态极为重要,隐藏状态是连接各隐藏层各神经元的中介值。如上图,在第一层中,在时间步tt,RNN隐藏层神经元得到隐藏状态ht(1)h_{t}^{(1)},在时间步t+1t+1,则接受来自上一个时间步的隐藏层输出ht(1)h_{t}^{(1)},得到新的隐藏状态ht+1(1)h_{t+1}^{(1)}。而从垂直方向上看,各层之间,也通过隐藏状态所连接,对于L1L_1L2L_2L2L_2在水平的时间轴上,各神经元通过隐藏状态ht(2)h_{t}^{(2)}连接,而层间还将接受前一层的ht(1)h_{t}^{(1)}的值来作为xtx_t的值,从而获得到该层新的隐藏状态。因此,RNN是一个在水平方向和垂直方向上,均可扩展的结构(水平方向上只是人为添加的易于理解的状态,在工程实践中不存在水平方向的设置)。

根据RNN的定义,可以简单地给出RNN的前向传播过程:

ht=g(Wxt+Vht1+b)h_t=g\left(Wx_{t}+Vh_{t-1}+b\right)

如上式,对于某一层,WVbW、V、b均为模型需要学习的参数,通过上图RNN结构图的对应,则应为L1L_1层水平方向所有神经元的参数,**同一层的RNN单元参数相同,即参数共享。**若考虑多层RNN,则可将上式改为:

ht[i]=g(W[i]ht[i1]+V[i1]ht1[i]+b[i])h_{t}^{[i]}=g\left(W^{[i]}h_{t}^{[i-1]}+V^{[i-1]}h_{t-1}^{[i]}+b^{[i]}\right)

为了简化研究,下文统一对单层RNN进行讨论。

值得注意的是,单层RNN前向传播可做如下变换:

Wxt+Vht=[WV]×[xtht1]Wx_t+Vh_t=\left[\begin{array}{cc}W&V\end{array}\right]\times\left[\begin{array}{c}x_t\\h_t-1\end{array}\right]

为此,我们不妨将参数进行统一表示:W=[W;V]W=\left[W;V\right],其中[;][\cdot;\cdot]表示拼接操作,则前向传播变为ht=g(W[ht1;xt]+b)h_t=g\left(W[h_{t-1};x_t]^{\top}+b\right)

再获得隐藏状态后,若需要获得每一个时间步的输出,则需要进一步进行线性变换:

ot=Vht+bo,    yt=g(ot)o_t=Vh_t+b_o, \;\;y_t=g(o_t),其中VbV、b为参数,g()g(\cdot)为**函数,如softmax。

深入理解RNN与LSTM

针对单层RNN,可采用上述结构进行描述。

RNN的反向传播

为简化分析,选用RNN的最后时间步的隐藏状态(无输出层)直接作为输出层,即output=ht=g(W[ht1;xt]+b)output=h_t=g\left(W\left[h_{t-1};x_{t}\right]^{\top}+b\right),若为分类问题,则g()g(\cdot)通常为Softmax。定义问题的损失函数为J(θ)=Loss(output,yθ)J(\theta)=Loss\left(output,y|\theta\right),则在进行反向传播时,需要计算WbW、b的梯度,可进行如下推导:

ΔW=J(θ)W=WLoss(output,y)=Loss(output,y)g(W[ht1;x]+b)W=Loss(output,y)g()[ht1;xt]\Delta W=\frac{\partial J(\theta)}{\partial W}=\frac{\partial}{\partial W} Loss(output,y)=Loss(output,y)'\frac{\partial g\left(W\left[h_{t-1};x\right]^{\top}+b\right)}{\partial W}=Loss(output,y)'g(\cdot)'[h_{t-1};x_{t}]

然而,在RNN的反向传播中,不仅需要根据垂直方向进行梯度推导,同时需要根据水平方向,按照时间步进行梯度推导,即RNN中的BPTT(Back Propagation Through Time)反向传播。从公式中也可以看出,在前向传播过程中hth_t是关于WWht1h_{t-1}的函数值,即ht=f(W[ht1;xt]+b)h_{t}=f\left(W\left[h_{t-1};x_{t}\right]^\top + b\right),则hth_t可以进一步进行微分,于是将hth_t关于WW求偏导,以循着时间轴更新t1t-1时刻的WW

ht[ht1;xt][ht1;xt]W=f()W[ht2;xt1]ΔWt1=Loss(output,y)g()f()W[ht2;xt1]\frac{\partial h_t}{\partial [h_{t-1};x_t]}\frac{\partial [h_{t-1};x_t]}{\partial W}=f(\cdot)'W[h_{t-2};x_{t-1}]\Rightarrow \Delta W_{t-1}=Loss(output,y)'g(\cdot)'f(\cdot)'W[h_{t-2};x_{t-1}]

根据反向传播的规则,每个在当前时间步tt应向前追溯直到t0t_0,计算梯度并更新参数,而在RNN中时间步中的WW参数被所有步共享,因此梯度是对同一个参数计算,为此可以将梯度作求和,一次性更新至WW,如图每个箭头表示一次梯度计算,则在t3t_3时刻计算梯度时,不仅需要直接计算当前时刻的梯度,还仍需根据时间轴,分别计算t2,t1t_2,t_1时刻的梯度。

注:本推导在假设RNN仅使用一个输出,即最后一个时间步的输出为最终输出,而RNN在每个时间步均有输出,若考虑多个输出,则损失函数不同,即损失为各时间步损失的总和,而在计算梯度时,需要对每个时间步输出均计算一个输出,即Loss=tJ(θ)Loss=\sum\limits^t J(\theta).

深入理解RNN与LSTM

则在t1t\rightarrow1的过程中,WW更新的梯度为ΔW=(Loss(output,y)g()[ht1;xt])+k=1T1((t=kTLoss(output,y)g()f()W)[hk1;xk])\Delta W=\left(Loss(output,y)'g(\cdot)'[h_{t-1};x_t]\right)+\sum\limits_{k=1}^{T-1}\left(\left(\prod \limits_{t=k}^{T}Loss(output,y)'g(\cdot)'f\left(\cdot\right)'W\right)[h_{k-1};x_{k}]\right)

对于偏置bb采用相同方式推导,此处不再重复推导。

注意:此处和后文若无特殊说明,均只讨论单层RNN,多层RNN则将RNN单元视为FCN中层即可。

RNN的梯度弥散与爆炸

根据上节的推导,可知,在进行BPTT时,RNN单元的反向传播梯度如下:

ΔWt=(Loss(output,y)g()tTtf(W[ht1;xt]+b))[ht1;xt]\Delta W_{t}=\left(Loss(output,y)'g(\cdot)'\prod \limits_{t}^{T-t}f\left(W\left[h_{t-1};x_t\right]^\top + b\right)'\right)\left[h_{t-1};x_{t}\right]

若**函数f()f(\cdot)采用tanhtanhsigmoidsigmoid,图像如图:

深入理解RNN与LSTM

对**函数求导,当f(x)=11+exf(x)=\frac{1}{1+e^{-x}}时,f(x)=(1+ex)2ex=ex(1+ex)2=ex+1(1+ex)2+1(1+ex)2=f(x)+f(x)2=f(x)(1f(x))f(x)'=-(1+e^{-x})^{-2}e^{-x}=-\frac{e^{-x}}{(1+e^{-x})^2}=-\frac{e^{-x}+1}{(1+e^{-x})^2}+\frac{1}{(1+e^{-x})^2}=-f(x)+f(x)^2=f(x)\left(1-f\left(x\right)\right)

f(x)=exexex+exf(x)=\frac{e^x-e^{-x}}{e^x+e^{-x}}时,

f(x)=ex+exex+ex+(exex)(1ex+ex)2(exex)(1)=1f(x)2\begin{aligned}f(x)'&=\frac{e^x+e^{-x}}{e^x+e^{-x}}+(e^x-e^{-x})\cdot\left(\frac{1}{e^x+e^{-x}}\right)^2\cdot\left(e^x-e^{-x}\right)(-1)\\&=1-f(x)^2\end{aligned}

sigmoidsigmoidtanhtanh导数图像如下图所示:

深入理解RNN与LSTM

从图像可以看出,在**函数的两端,导数均介接近于0,根据上述RNN梯度的推导,假设当前处于最后一个时间步tt,则在向前BPTT时,会得出f()\prod f(\cdot)'的计算,当f(x)f(x)值接近于两端时,则其梯度异常接近于0,并且sigmoidsigmoid导数最大值才为14\frac{1}{4},多个接近于0的数相乘,将导致梯度呈指数下降趋势,接近于0,导致梯度弥散。随着序列的变长,f()\prod f(\cdot)'的值越小,这便说明,RNN不具备长期记忆,而只具备短期记忆。

由于梯度弥散,导致在序列长度很长时,无法在较后的时间步中,按照梯度更新较前时间步的WW,导致无法根据后续序列来修改前向序列的参数,使得前向序列无法很好地做特征提取,使得在长时间步过后,模型将无法再获取有效的前向序列记忆信息

梯度弥散,在RNN属于重要问题,为此便提出了以LSTM、GRU等结构的变种,来解决RNN短期记忆的瓶颈。同样的,根据上述梯度的推导,梯度中W\prod W将会导致参数累乘,若初始参数较大时,则较大数相乘,将导致梯度爆炸,然而梯度爆炸相对于梯度弥散较容易解决,通常加入梯度裁剪即可一定程度缓解。

长短期记忆网络(Long Short Term Memory)

前面说到,RNN单元在面对长序列数据时,很容易便遭遇梯度弥散,使得RNN只具备短期记忆,即RNN面对长序列数据,仅可获取较近的序列的信息,而对较早期的序列不具备记忆功能,从而丢失信息。为此,为解决该类问题,便提出了LSTM结构,其核心关键在于:

  1. 提出了门机制:遗忘门、输入门、输出门
  2. 细胞状态:在RNN中只有隐藏状态的传播,而在LSTM中,引入了细胞状态。

LSTM的前向传播

如下图,为三个LSTM单元的连结,其中相较于传统RNN单元,其多了上下两条轴,分别用于承载细胞状态CC及隐藏状态hh的信息流动,而其中σ\sigma则被称为门,通过乘运算于和运算实现数据的合并于过滤。

为更好地比较LSTM与RNN的区别,再此将RNN前向传播记录如下:

ht=f(Uht1+Wxt+bh)ot=Vht+boyt=g(ot)\begin{aligned}h_t&=f(Uh_{t-1}+Wx_t+b_h)\\o_t&=Vh_t+b_o\\y_t&=g(o_t)\end{aligned}

深入理解RNN与LSTM

紧接着,对LSTM的门进行定义,其均为:

gatef,i,o(ht1,xt)=σ(Uht1+Wxt+b)gate_{f,i,o}(h_{t-1}, x_t)=\sigma(Uh_{t-1}+Wx_t+b)

其中,f,i,of,i,o分别表示遗忘门、输入门、输出们,对应地,U,W,bU,W,b在不同门中,也应为不同的参数。为此,可卸除LSTM详细的前向传播过程。

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-IVanN6eD-1582277223298)(LSTM-Cell.eps)]

如图中各σ\sigma,则表示各门,其与×,+\times,+运算做到了信息过滤和叠加。

在遗忘门:ft=σ(Ufht1+Wfxt+bf)f_t=\sigma(U_fh_{t-1}+W_fx_t+b_f),由之前所介绍的sigmoidsigmoid函数可知,其函数值在(0,1)(0,1)范围内。这里可以思考一下计算机中,门电路的思想,在逻辑电路中,分为“与门”,“或门”,“非门”等,对于“与门”,只有当两者均为1时结果为1,同样地对于遗忘门的运算,其输出值为(0,1)(0,1),当进行乘法运算时,是否也能达到信息过滤的效果呢?

结果很显然,当任何一个数乘以0时,其值为0,那么在后续的线性运算过程中其仍然为0,便可表示,其信息被忽略,因为到下一层时,其未产生信息叠加。

深入理解RNN与LSTM

同理,对于输入门,我们有:

it=σ(Uiht1+Wixt+bi)i_t=\sigma(U_ih_{t-1}+W_ix_t+b_i)

深入理解RNN与LSTM

而输入门主要控制对输入的信息进行过滤,即在输入时选择性地抛弃某些信息,而抛弃的信息,即为输入门中输出为0的特征维度。同时,在时间步tt,原输入应为:ht1,xth_{t-1},x_t,按照传统的RNN的前向传播,输入应经过线性变换后进行**,并且**函数通常使用tanhtanh,即:Ct~=tanh(Uxht1+Wxxt+bx)\widetilde{C_t}=tanh(U_xh_{t-1}+W_xx_t+b_x)

上述输入的变化,可以对应RNN的输入过程。

由于加了门机制,则需要对输入的信息,进行过滤,而输入信息在LSTM中包含:细胞状态、隐藏状态、当前时间步输入。其中隐藏状态、当前时间步,已经作为输入经过传统的RNN变换得到Ct~\widetilde{C_t},还剩下细胞状态,因此需要进一步将细胞状态与Ct~\widetilde{C_t}融合,并得到新的细胞状态:

深入理解RNN与LSTM

Ct=ftCt1+itCt~C_t=f_t\odot C_{t-1} + i_t \odot \widetilde{C_t},其中\odot表示element-wise乘积。

在输出门中,同样采用相同的方式得到门概率分布:ot=σ(Uoht1+Woxt+bo)o_t=\sigma(U_oh_{t-1}+W_ox_t+b_o)。输出门的作用在于,对于要输出给下一个时间步的信息,进行一定地过滤,有选择性地保留和去除之前时间步的某些数据。因此,有ht=ottanh(Ct)h_t=o_t\odot tanh(C_t) 。得到hth_t后,便可进一步得到yty_t,其过程与RNN一致。至此,LSTM的前向传播过程即以结束。

深入理解RNN与LSTM

LSTM的结构有效地解决了RNN的短期依赖瓶颈。但是从模型结构可以看出,相较于RNN,LSTM含有更多的参数需要学习,从而导致LSTM的学习速度大大降低。

上述公式推导过程中,同样可以采用拼接的方式,使得W:=[UW]W:=\left[\begin{array}{cc}U&W\end{array}\right],而X:=[ht1xt]X:=\left[\begin{array}{c}h_{t-1}\\x_t\end{array}\right]

对前向传播的过程进行整理,可得:

Ct=σ(Uiht1+Wixt+bi)tanh(Uxht1+Wxxt+bx)+σ(Ufht1+Wfxt+bf)Ct1ht=tanh(Ct)σ(Uoht1+Woxt+bo)yt=f(Vyht+by)\begin{aligned}C_t&=\sigma(U_ih_{t-1}+W_ix_t+b_i)\odot tanh(U_xh_{t-1}+W_xx_t+b_x) + \sigma(U_fh_{t-1}+W_fx_t+b_f) \odot C_{t-1}\\ h_t&=tanh(C_t)\odot \sigma(U_oh_{t-1}+W_ox_t+b_o)\\y_t&=f(V_yh_t+b_y)\end{aligned}