神经网络中的梯度不稳

神经网络中的梯度不稳

一、概述

  深度网络容易出现梯度不稳(梯度消失、梯度爆炸)问题,造成网络学习停滞。

梯度消失

  在深层网络中,如果**函数的导数小于1,根据链式求导法则,靠近输入层的参数的梯度因为乘了很多的小于1的数而越来越小,最终就会趋近于0。例如sigmoid函数,其导数f(x)=f(x)(1f(x))f'(x)=f(x)(1−f(x))的值域为014(0,\frac{1}{4}),极易发生这种情况。
所以梯度消失出现的原因经常是因为网络层次过深,以及**函数选择不当等。

  梯度消失的表现:模型无法从训练数据中获得更新,损失几乎保持不变。

梯度爆炸

  同梯度消失的原因一样,求解损失函数对参数的偏导数时,在梯度的连续乘法中总是遇上很大的绝对值,部分参数的梯度因为乘了很多较大的数而变得非常大,导致模型无法收敛。
  所以梯度爆炸出现的原因也是网络层次过深,或者权值初始化值太大。

  梯度爆炸的表现:模型型不稳定,更新过程中的损失出现显著变化;训练过程中模型损失变成NaN。

二、RNN中梯度不稳的原因

  传统RNN的模型结构如下:

神经网络中的梯度不稳

  为简化说明,假设此处的时间序列长度为3,即在h0h_0给定的情况下,有如下状态:
{h1=Ux1+Wh0+b1o1=Vh1+b2h2=Ux2+Wh1+b1o2=Vh2+b2h3=Ux3+Wh2+b1o3=Vh3+b2\begin{cases} h_1 = Ux_1 + Wh_0 + b_1 \\ o_1 = Vh_1 + b_2 \\ h_2 = Ux_2 + Wh_1 + b_1 \\ o_2 = Vh_2 + b_2 \\ h_3 = Ux_3 + Wh_2 + b_1 \\ o_3 = Vh_3 + b_2 \\ \end{cases}

  在t=3t=3的时刻,损失函数为L3=12(y3o3)2L_3 = \frac{1}{2} (y_3-o_3)^2,中的损失函数为L=t=0TLtL=\sum_{t=0}^T L_t
  对t=3t=3时刻的权重矩阵求偏导,得到:
{L3V=L3o3o3VL3U=L3o3o3h3h3U+L3o3o3h3h3h2h2U+L3o3o3h3h3h2h2h1h1UL3W=L3o3o3h3h3W+L3o3o3h3h3h2h2W+L3o3o3h3h3h2h2h1h1W\begin{cases} \frac{\partial L_3}{\partial V} = \frac{\partial L_3}{\partial o_3}\frac{\partial o_3}{\partial V} \\ \frac{\partial L_3}{\partial U} = \frac{\partial L_3}{\partial o_3}\frac{\partial o_3}{\partial h_3}\frac{\partial h_3}{\partial U} + \frac{\partial L_3}{\partial o_3}\frac{\partial o_3}{\partial h_3}\frac{\partial h_3}{\partial h_2}\frac{\partial h_2}{\partial U} + \frac{\partial L_3}{\partial o_3}\frac{\partial o_3}{\partial h_3}\frac{\partial h_3}{\partial h_2}\frac{\partial h_2}{\partial h_1}\frac{\partial h_1}{\partial U} \\ \frac{\partial L_3}{\partial W} = \frac{\partial L_3}{\partial o_3}\frac{\partial o_3}{\partial h_3}\frac{\partial h_3}{\partial W} + \frac{\partial L_3}{\partial o_3}\frac{\partial o_3}{\partial h_3}\frac{\partial h_3}{\partial h_2}\frac{\partial h_2}{\partial W} + \frac{\partial L_3}{\partial o_3}\frac{\partial o_3}{\partial h_3}\frac{\partial h_3}{\partial h_2}\frac{\partial h_2}{\partial h_1}\frac{\partial h_1}{\partial W} \end{cases}

  可以看出对VV没有长期依赖,而对于WUW、U存在长期依赖。从上式可以推导出一般情况下的偏导公式:
{LtU=k=0tLtototht(j=k+1thjhj1)hkULtW=k=0tLtototht(j=k+1thjhj1)hkW\begin{cases} \frac{\partial L_t}{\partial U} = \sum_{k=0}^t \frac{\partial L_t}{\partial o_t} \frac{\partial o_t}{\partial h_t} (\prod_{j=k+1}^t \frac{\partial h_j}{\partial h_{j-1}}) \frac{\partial h_k}{\partial U} \\ \frac{\partial L_t}{\partial W} = \sum_{k=0}^t \frac{\partial L_t}{\partial o_t} \frac{\partial o_t}{\partial h_t} (\prod_{j=k+1}^t \frac{\partial h_j}{\partial h_{j-1}}) \frac{\partial h_k}{\partial W} \end{cases}
  如果加上**函数,hj=f(Uxj+Wsj1+b1)h_j = f(Ux_j + Ws_{j-1} + b_1),则有
j=k+1thjhj1=j=k+1tfW\prod_{j=k+1}^t \frac{\partial h_j}{\partial h_{j-1}} = \prod_{j=k+1}^t f' W
  其中ftanhsigmodf为\tanh或者sigmod
  **函数tanh及其导数如下图所示:
神经网络中的梯度不稳

  **函数sigmod及其导数如下图所示:
神经网络中的梯度不稳

  由上图可以看出tanh1\tanh' \leq 1,而sigmod0.25sigmod' \leq 0.25。因此如果当WW或者UU也是大于0小于1时,若tt很大,则j=k+1tfW\prod_{j=k+1}^t f' W就会趋于0;反之如果WW或者UU很大时,j=k+1tfW\prod_{j=k+1}^t f' W就会趋于无穷。这就是RNN中梯度消失和梯度爆炸的原因。

三、LSTM解决梯度消失

  梯度消失和爆炸的根本原因是在于j=k+1thjhj1\prod_{j=k+1}^t \frac{\partial h_j}{\partial h_{j-1}}。要消除这种现象,就要把这个在求偏导的过程中去掉,即使得hjhj10\frac{\partial h_j}{\partial h_{j-1}} \approx 0,或者hjhj11\frac{\partial h_j}{\partial h_{j-1}} \approx 1

  LSTM采用门(gate)来控制输入输出。门控单元的相关计算如下:

{ft=σ(Wf[ht1,xt]+bf)it=σ(Wi[ht1,xt]+bi)ot=σ(Wo[ht1,xt]+bo)\begin{cases} f_t = \sigma(W_f·[h_{t-1},x_t]+b_f) \\ i_t = \sigma(W_i·[h_{t-1},x_t]+b_i) \\ o_t = \sigma(W_o·[h_{t-1},x_t]+b_o) \end{cases}

  模型的当前状态:
ht=ottanh(Ct)h_t = o_t \odot \tanh(C_t)
其中Ct=ftCt1+itC~tC_t = f_t \odot C_{t-1} + i_t \odot \tilde{C}_tC~t=tanh(WC[ht1,xt]+bC)\tilde{C}_t = \tanh(W_C·[h_{t-1},x_t]+b_C)。将hth_t展开后可以得到:
ht=tanh[σ(Wf[ht1,xt]+bf)+σ(Wi[ht1,xt]+bi)]h_t = \tanh[\sigma(W_f·[h_{t-1},x_t]+b_f) + \sigma(W_i·[h_{t-1},x_t]+b_i)]
  对比传统RNN模型中求偏导过程中的一项:
j=k+1thjhj1=j=k+1tfW\prod_{j=k+1}^t \frac{\partial h_j}{\partial h_{j-1}} = \prod_{j=k+1}^t f' W
  在LSTM中,该项变为
j=k+1thjhj1=j=k+1ttanhσ(Wfxt+bf)\prod_{j=k+1}^t \frac{\partial h_j}{\partial h_{j-1}} = \prod_{j=k+1}^t \tanh' \sigma(W_fx_t+b_f)
  令z=tanh(x)σ(y)z=\tanh'(x)\sigma(y),则zz的函数如下图所示:
神经网络中的梯度不稳
  可以看到该函数值基本上不是0就是1。

  因为j=k+1thjhj1=j=k+1ttanhσ(Wfxt+bf)01\prod_{j=k+1}^t \frac{\partial h_j}{\partial h_{j-1}} = \prod_{j=k+1}^t \tanh' \sigma(W_fx_t+b_f) \approx 0|1,从而可以解决传统RNN中梯度消失的问题。

  虽然但是,在遗忘门后,如果遗忘门接近1(如模型初始化时会把bfb_f设置成较大的正数,让遗忘门饱和),远距离的梯度不会消失;如果遗忘门接近0,更有可能是模型学到了某些特征(如文本中的 “not”、“but” 等)选择对前面数据进行遗忘。大多数情况下遗忘门仍然是一个0~1的数,LSTM仍然是有可能发生梯度消失的,只是概率远远低于RNN。

  因为在LSTM中,梯度的传播有很多条路径。其中ct1ct=ftct1+itc~tc_{t-1} \rightarrow c_t = f_t \odot c_{t-1} + i_t \odot \tilde{c}_t上只有逐元素相乘和相加的操作,梯度流最稳定;但是其他路径(如ct1ht1itctc_{t-1} \rightarrow h_{t-1} \rightarrow i_t \rightarrow c_t)上梯度流与传统RNN类似,照样会发生相同的权重矩阵反复连乘。由于=总的远距离梯度 = \sum 各条路径的远距离梯度,即便其他远距离路径梯度消失了,只要保证有一条远距离路径梯度不消失,总的远距离梯度就不会消失(+=正常梯度 + 消失梯度 = 正常梯度)。因此LSTM通过改善一条路径上的梯度问题拯救了总体的远距离梯度。

  LSTM仍然是有可能发生梯度爆炸的,但是因为回传路径复杂多样,而且和传统RNN相比多经过了很多次**函数,因此频率比较低。实际中梯度爆炸一般结合梯度裁剪 (gradient clipping) 解决。

四、其它解决梯度不稳的方法

4.1 预训练加微调

  此方法来自Hinton在2006年发表的一篇论文,提出采取无监督逐层训练方法,其基本思想是每次训练一层隐节点,训练时将上一层隐节点的输出作为输入,而本层隐节点的输出作为下一层隐节点的输入,此过程就是逐层“预训练”;在预训练完成后,再对整个网络进行“微调”。这种做法相当于是先寻找局部最优,然后整合起来寻找全局最优,但是目前应用的不是很多。
  预训练能够缓解梯度不稳定现象的原因在于其降低了参数训练对梯度下降法的依赖性。本质上讲梯度不稳定现象并没有消失,只是降低了它对模型训练结果的影响。此外,无监督学习的特征有利于避免过拟合现象的发生。

4.2 梯度裁剪与正则化

  梯度裁剪主要是针对梯度爆炸提出的,其思想是设置一个阈值,然后更新梯度的时候,如果梯度超过这个阈值,那么就将其强制限制在这个范围之内,从而防止梯度爆炸。
  另外一种解决梯度爆炸的手段是采用权重正则化。比较常见的是l1l1正则,和l2l2正则。正则化是通过对网络权重做正则限制过拟合,其损失函数的形式为:
L=(yWTx)2+αW2L = (y-W^Tx)^2+\alpha||W||^2
  其中,α\alpha是指正则项系数。如果发生梯度爆炸,权值的范数就会变的非常大,通过正则化项,可以部分限制梯度爆炸的发生。

4.3 **函数的角度

  ReLU:其设计思想为如果**函数的导数为1,那么就不存在梯度消失爆炸的问题,每层网络都可以得到相同的更新速度。
ReLU(x)=max(x,0)={0,x<0x,x>0ReLU(x) = \max(x,0) = \begin{cases} 0,x<0 \\ x,x>0 \end{cases}
神经网络中的梯度不稳

  ReLU函数的导数在正数部分是恒等于1的,因此在深层网络中使用ReLU**函数就不会导致梯度消失和爆炸的问题。ReLU缓解了梯度消失、爆炸的问题,而且计算方便,计算速度快,从而加速了网络的训练。但是由于其负数部分恒为0,会导致一些神经元无法**;此外它的输出不是以0为中心的。

  LeakReLU:LeakReLU就是为了解决ReLU的0区间带来的影响。
LeakReLU(x)=max(kx,x)={kx,x<0x,x>0LeakReLU(x) = \max(kx,x) = \begin{cases} kx,x<0 \\ x,x>0 \end{cases}
神经网络中的梯度不稳

  其中kk是leak系数,一般选择0.1或者0.2,或者通过学习而来

  eLU:eLU**函数也是为了解决eLU的0区间带来的影响。
eLU(x)={α(ex1),x<0x,x>0eLU(x) = \begin{cases} \alpha(e^x-1),x<0 \\ x,x>0 \end{cases}
神经网络中的梯度不稳

  相对于leakrelu来说,elu的计算量较大。

4.4 权重初始化角度

  • Xavier初始化+tanh**函数

  若**函数为线性恒等映射,则有
Var(al)=Var(zl)=fan_in×Var(w)×Var(al1)Var(a^l)=Var(z^l) = fan \_ in \times Var(w) \times Var(a^{l-1})

  而tanh**函数在0附近可近似为恒等映射,即tanh(x)xtanh(x) \approx x

  Xavier的初始化过程同时考虑了前向过程和反向过程,使用fan_infan\_ infan_outfan\_ out的平均数对方差进行归一化,权重服从如下分布的高斯采样:
W N(0,2fan_in+fan_out)W~N(0,\frac{2}{fan\_ in+fan\_ out})
  而由于均匀分布的方差与分布范围的关系为
Var(U(n,n))=n23Var(U(-n,n))=\frac{n^2}{3}
  令Var(U(n,n))=n23=2fan_in+fan_outVar(U(-n,n))=\frac{n^2}{3} = \frac{2}{fan\_ in+fan\_ out},则有
n=6fan_in+fan_outn = \frac{\sqrt{6}}{\sqrt{fan\_ in+fan\_ out}}
  即权重也可以从如下均匀分布中采样:
WU(6fan_in+fan_out,6fan_in+fan_out)W \sim U(-\frac{\sqrt{6}}{\sqrt{fan\_ in+fan\_ out}},\frac{\sqrt{6}}{\sqrt{fan\_ in+fan\_ out}})

  • Kaiming初始化+ReLU**函数

  由于有
Var(z)=fan_in×(Var(w)Var(a)+Var(w)E(a)2)Var(z) = fan \_ in \times (Var(w)Var(a) + Var(w)E(a)^2)
  但是对于ReLU来说,不可以像tanh一样将E(a)E(a)近似看作0,因此进行如下推导:
Var(z)=fan_in×(Var(w)Var(a)+Var(w)E(a)2)Var(z) = fan \_ in \times (Var(w)Var(a) + Var(w)E(a)^2)

=fan_in×(Var(w)(E(a2)E(a)2)+Var(w)E(a)2) = fan \_ in \times (Var(w)(E(a^2)-E(a)^2) + Var(w)E(a)^2)

=fan_in×Var(w)×E(a2)=fan \_ in \times Var(w) \times E(a^2)

  从而可以得到
Var(zl)=fan_in×Var(wl)×E((al1)2)Var(z^l)=fan \_ in \times Var(w^l) \times E((a^{l-1})^2)

  假定wl1w^{l-1}关于原点对称,则可以认为Zl1Z^{l-1}的分布也关于原点对称。
  对于一个关于原点0对称的分布,经过ReLU后,仅保留大于0的部分,则有
Var(x)=(x0)2p(x)dxVar(x) = \int_{-\infty}^\infty (x-0)^2p(x)dx
=2x2p(x)dx=2E(max(0,x)2)=2 \int_{-\infty}^\infty x^2p(x)dx=2E(max(0,x)^2)
  进一步可以得出:
Var(zl)=12×fan_in×Var(xl)×Var(zl1)Var(z^l) = \frac{1}{2} \times fan \_ in \times Var(x^l) \times Var(z^{l-1})

  将系数缩放为1,即
12×fan_in×Var(xl)=1\frac{1}{2} \times fan \_ in \times Var(x^l) = 1
  因此有
Var(w)=2fan_inVar(w) = \frac{2}{fan \_ in}

  即从前向传播考虑,每层的权重初始化为:
WN(0,2fan_in)W \sim N(0,\frac{2}{fan \_ in})
  同理,从后向传播考虑,每层的权重初始化为:
WN(0,2fan_out)W \sim N(0,\frac{2}{fan \_ out})

  • BatchNorm层

  Batchnorm具有加速网络收敛速度,提升训练稳定性的效果。其本质上是解决反向传播过程中的梯度问题。batchnorm全名是batch normalization,简称BN,即批规范化,通过规范化操作将输出信号x规范化保证网络的稳定性。

  简单来说,若有正向传播f2=f1(wTx+b)f_2 = f_1(w^T *x+b),那么反向传播就有f2w=f2f1x\frac{\partial f_2}{\partial w} = \frac{\partial f_2}{\partial f_1}x。反向传播式子中有xx的存在,所以xx的大小影响了梯度的消失和爆炸。Batchnorm就是通过对每一层的输出规范为均值和方差一致的方法,消除了xx带来的放大缩小的影响,进而缓解了梯度消失和爆炸的问题。

4.5 网络结构的角度

  使用残差网络用来解决深度网络的不容易训练和退化的问题。
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-WojtRu4F-1598503318395)(./res.png)]

  对于一个堆积层结构,当输入为xx时其学习到的特征记为H(x)H(x),设定残差为F(x)=H(x)xF(x)=H(x)-x,则原始的学习特征是F(x)+xF(x)+x。当残差为0时,此时堆积层仅仅做了恒等映射,至少网络性能不会下降。实际上残差不会为0,使得堆积层在输入特征基础上学习到新的特征,从而拥有更好的性能。
  直观上看残差学习需要学习的内容少,因为残差一般会比较小,学习难度小。残差单元可以表示为:
{yl=h(xl)+F(xl,Wl)xl+1=f(yl)\begin{cases} y_l=h(x_l)+F(x_l,W_l) \\ x_{l+1} = f(y_l) \end{cases}
  其中xx表示残差单元的输入,yy表示输出;ff是ReLU**函数,因此可以得出学习特征为:
xL=xl+i=lL1F(xi,Wi)x_L = x_l + \sum_{i=l}^{L-1} F(x_i,W_i)
  通过链式求导,计算反向传播的梯度:
Lxl=LxLxLxl=LxL(1+xLi=lL1F(xi,Wi))\frac{\partial L}{\partial x_l} = \frac{\partial L}{\partial x_L}·\frac{\partial x_L}{\partial x_l} = \frac{\partial L}{\partial x_L} · (1+\frac{\partial}{\partial x_L}\sum_{i=l}^{L-1}F(x_i,W_i))
  其中LxL\frac{\partial L}{\partial x_L}​表示的损失函数到达L的梯度,小括号中的1表明短路机制可以无损地传播梯度。残差梯度基本上不可能全为-1,而且有1的存在,可以缓解梯度消失。

4.6 损失函数的角度

  通过损失函数来抵消**函数求导后造成的梯度消失影响。常见的方式有如下几种:

  • linear+MSE loss;
  • sigmoid+BCE loss;
  • softmax+CrossEntropy loss.

五、总结

  在深度多层感知器网络中,梯度消失可能导致神经网络不稳定,使之难以收敛;而梯度爆炸则会使模型不能从训练数据中学习,甚至是无法更新的NaN权重值。
  而针对梯度不稳的处理方式,主要包括预训练加微调、梯度裁剪、权重正则、调整**函数、修改初始化权重、使用残差结构、使用LSTM网络等方式。