深度残差网络ResNet

paper https://arxiv.org/pdf/1512.03385.pdf

提出动机

ResNet是为了解决深度神经网络中由于层数过多带来的模型退化问题(degradation)。
一般情况下,模型退化主要有以下几种原因:

  • 过拟合,层数越多,参数越复杂,泛化能力弱
  • 梯度消失/梯度爆炸,层数过多,梯度反向传播时由于链式求导连乘使得梯度过大或者过小,使得梯度出现消失/爆炸,对于这种情况,可以通过BN(batch normalization)可以解决
  • 由深度网络带来的退化问题,一般情况下,网络层数越深越容易学到一些复杂特征,理论上模型效果越好,但是由于深层网络中含有大量非线性变化,每次变化相当于丢失了特征的一些原始信息,从而导致层数越深退化现象越严重。

如下图所示,56层的网络要比20层的网络效果差,出现退化现象。
深度残差网络ResNet
对于深层网络的退化现象,何凯明大神希望用一种方式,使得深层神经网络至少能和浅层神经网络相持平(深层不能比浅层的差),因此设计了一种残差结构来解决该问题。

这里说个概念:恒等映射,让x的映射等于x,即H(x)=xH(x) = x
如下图所示,左边为浅层网络,右边为深层网络,如果想让左右两边持平,就得让后面的五层网络进行恒等映射,即输入等于输出,相当于不起任何作用。ResNet就是利用这种恒等映射,使得网络加深时,至少能保证和浅层网络相持平。
深度残差网络ResNet

残差单元设计

基于上述思想,希望设计一个如下的网络结构,输入x,输出H(x)=F(x)=xH(x) =F(x)=xF(x)F(x)表示残差单元的输出,H(x)H(x)表示期望的输出,即恒等映射。
深度残差网络ResNet
但是有种更巧妙的方式,将残差单元设计成如下的方式
深度残差网络ResNet
H(x)=F(x,w)+xH(x) = F(x,w) + x,所以F(x,w)=H(x)xF(x,w) = H(x) - x,即残差形式,ww为残差结构中的参数矩阵。当H(x)H(x)为恒等变换时,F(x,w)=0F(x,w)=0
这样设计非常巧妙,主要有以下优点:

  1. 首先直观理解,如上图所示,x通过shortcuts(跳接操作)作为残差单元的一部分,如果图中的两层网络对结果无益,那么直接通过跳接输出结果(即w=0w=0),即输入等于输出;
  2. 使得残差单元中参数学习更加容易,如果直接使得F(x,w)=H(x)=xF(x,w)=H(x)=x,那么ww通常是个非稀疏矩阵,如果令F(x,w)=H(x)xF(x,w)=H(x)-x,在恒等映射情况下,F(x,w)=0=>w=0F(x,w)=0=>w=0,而参数初始化时一般都在0附近,使其产生恒等映射更加容易;
  3. H(x)=F(x,w)+1H'(x) = F'(x,w) + 1,除非F(x,w)=1F'(x,w)=-1,否则不会出现梯度消失的情况;

在实际应用中需要额外的权重矩阵来衔接维度不同的网络单元。
总之,个人理解是ResNet是通过引入了一个巧妙的残差结构,在深层网络中,让模型自动取舍是否需要某些网络或者哪些网络不需要起太大作用。