对于 ResNet 残差网络的思考——残差网络可以解决梯度消失的原因

导言:

从神经网络的历史上来看,深层网络由于梯度消失无法训练这个问题目前为止一共有两次很大的突破。第一次是神经网络开山鼻祖 Hinton 先生提出的 relurelu **函数取代了原来的 sigmoidsigmoidtanhtanh 函数,使得对于**函数的导数变为了 11 。第二次是 何凯明 大神在 2015 年的论文 Deep Residual Learning for Image Recognition 中使用残差模块利用 shortcut 解决了深层网络消失的问题,使得训练数百层甚至数千层的网络成为了可能。

本文接着此前发表的博文:对于 ResNet 残差网络的思考——残差网络性能优于原始浅层网络的原因

深入思考了 ResNet 残差网络是如何解决梯度消失问题的,同时本文也对神经网络反向传播的计算提出了一种计算图,方便在以后研究过程中,可以通过直接看图,来了解在所构建的神经网络在反向传播过程中的具体细节。

1.问题:

最近读神经网络文章时遇到了一个问题。一开始神经网络使用 sigmoidsigmoid 以及 tanhtanh 函数,如下图所示为 sigmoidsigmoid 以及 tanhtanh 函数及其导数示意图。
对于 ResNet 残差网络的思考——残差网络可以解决梯度消失的原因
对于 ResNet 残差网络的思考——残差网络可以解决梯度消失的原因
这种**函数正如网上的大多数博文所说的一样,由于其导数在大部分区域内小于 11,会使得多层网络在反向传播时很多个小于1的导数不断相乘最后使得梯度越来越小,反向传播最后到输入层附近时几乎不更新。

历史上,深度学习开山祖师 Hinton 为了解决这个问题提出了一种新的**函数就是 relurelu,它的函数及其导数图如下所示。
对于 ResNet 残差网络的思考——残差网络可以解决梯度消失的原因
可以看到这种**函数的导数在大于 00 的时候恒为1,这样就不会出现小于 11 的导数不断相乘导致梯度消失的问题。

按道理来说,梯度消失的问题就已经解决了呀,Happy Ending 本文完。但梯度消失的故事才刚刚开始 -_-!,梯度消失的问题曾一度将深度学习开山鼻祖 Hinton 先生逼到一层一层的训练其在 Science 提出的自编码器(auto-encoder)。这也是我一直困惑的问题,都有了 relurelu **函数解决梯度消失问题了,那后面 何凯明 大神提出来的 ResNet 为什么又将梯度消失问题解决了一遍?什么意思只用 relurelu **函数为什么不能训练深层网络呀?毕竟 relurelu **函数的导数都是 11 呀,问题出在了哪里?

2.梯度计算图:

为了更深入的理解这个问题。

PS: 而不是像其他分析 ResNet 残差网络的博文那样直接说残差网络通过引入恒等映射,恒等映射求导是 11 解决了梯度消失的问题。其实想想也知道这个解释不靠谱,因为 relurelu 在大于 0 的时候也是恒等映射。那不是说 ResNet 残差网络是多余的吗。

我创造了一种能够清楚表示神经网络反向传播时计算细节的梯度计算图。

2.1 求导链式法则的图形化表示

在开始画出神经网络的梯度计算图之前,我们首先来用梯度计算图画出 高等数学 中经常会用到的多元函数求导链式法则。这也是神经网络反向传播算法的基础。

对于多元函数求导链式法则来说,其数学表示式如下:

Fx=Fmmx+Fnnx\frac{\partial F}{\partial x}=\frac{\partial F}{\partial m}\frac{\partial m}{\partial x}+\frac{\partial F}{\partial n}\frac{\partial n}{\partial x}

其中 m,n 为中间变元。这样一个求导链式法则可以用如下的图形化的语言描述
对于 ResNet 残差网络的思考——残差网络可以解决梯度消失的原因
可以看到从 FF 走到 xx 一共有两条路,可以先由 FF 走到 mm 最终到 xx,也可以选择中间经过 nnxx。两种不同路径的选择用加号 ++ 相加,从 FF 走到 mmFm\frac{\partial F}{\partial m} 表示,从 mm 走到 xxmx\frac{\partial m}{\partial x} 表示。同一条路径的先后次序用乘号 ×\times 相乘。由此可以得到上面的数学表达式。同时我们使用方框来表示各个变量,由此来和神经网络示意图中表示神经元的符号 圆圈 相区分。

2.2 全链接网络 梯度计算图

如下图所示,左边是一个全链接神经网络示意图,右边是一个全链接网络 梯度计算图。由于在电脑上画图实在太难画了,我没有将它补充完整。
对于 ResNet 残差网络的思考——残差网络可以解决梯度消失的原因
由此我们可以直接通过看图来得到反向传播过程中权重是如何具体更新的。

比如 图中的权重 wn111w^{11}_{n-1} 来说,我们可以通过看图很容易写出它反向传播更新时的数学表达式。

Losswn111=Lossyn1yn1xn1xn1wn111\frac{\partial Loss}{\partial w^{11}_{n-1}}=\frac{\partial Loss}{\partial y^{1}_{n}}\frac{\partial y^{1}_{n}}{\partial x^{1}_{n}}\frac{\partial x^{1}_{n}}{\partial w^{11}_{n-1}}

其在图中的意义为:从 LossLoss 出发走到 wn111w^{11}_{n-1} 只有一条路,它从 LossLoss 出发经过 yn1y^{1}_{n} 再经过 xn1x^{1}_{n} 即可到达 wn111w^{11}_{n-1}

同时在上图的表示中我们可以看到这样的三叉路口一样的图形。
对于 ResNet 残差网络的思考——残差网络可以解决梯度消失的原因
它在正向传播计算中表示加权求和,但是在反向传播的 梯度计算图中的规则更为简单,满足以下关系。

xn1wn111=yn11\frac{\partial x^{1}_{n}}{\partial w^{11}_{n-1}}=y^{1}_{n-1}

xn1yn11=wn111\frac{\partial x^{1}_{n}}{\partial y^{1}_{n-1}}=w^{11}_{n-1}

表示从 xn1x_{n}^{1} 走到 wn111w^{11}_{n-1} 的结果为 yn11y^{1}_{n-1},从 xn1x_{n}^{1} 走到 yn11y^{1}_{n-1} 的结果为 wn111w^{11}_{n-1}

同样,我们也可以通过直接看 梯度计算图。 从图中我们可以看到从 LossLoss 走到 wn211w^{11}_{n-2} 存在着多条不同路径,yn1y^{1}_{n}yn2y^{2}_{n}yn3y^{3}_{n}yn4y^{4}_{n} 所代表的路径都可能,如下图红色线条所示。

对于 ResNet 残差网络的思考——残差网络可以解决梯度消失的原因

由此我们可以得到 权重 wn211w^{11}_{n-2} 反向传播时的数学表达式。

Losswn111=Lossyn1yn1xn1xn1yn11yn11xn11xn11wn211+Lossyn2yn2xn2xn2yn11yn11xn11xn11wn211+Lossyn3yn3xn3xn3yn11yn11xn11xn11wn211+Lossyn4yn4xn4xn4yn11yn11xn11xn11wn211\frac{\partial Loss}{\partial w^{11}_{n-1}}= \frac{\partial Loss}{\partial y^{1}_{n}}\frac{\partial y^{1}_{n}}{\partial x^{1}_{n}}\frac{\partial x^{1}_{n}}{\partial y^{1}_{n-1}}\frac{\partial y^{1}_{n-1}}{\partial x^{1}_{n-1}}\frac{\partial x^{1}_{n-1}}{\partial w^{11}_{n-2}} \\+\frac{\partial Loss}{\partial y^{2}_{n}}\frac{\partial y^{2}_{n}}{\partial x^{2}_{n}}\frac{\partial x^{2}_{n}}{\partial y^{1}_{n-1}}\frac{\partial y^{1}_{n-1}}{\partial x^{1}_{n-1}}\frac{\partial x^{1}_{n-1}}{\partial w^{11}_{n-2}} \\+\frac{\partial Loss}{\partial y^{3}_{n}}\frac{\partial y^{3}_{n}}{\partial x^{3}_{n}}\frac{\partial x^{3}_{n}}{\partial y^{1}_{n-1}}\frac{\partial y^{1}_{n-1}}{\partial x^{1}_{n-1}}\frac{\partial x^{1}_{n-1}}{\partial w^{11}_{n-2}} \\+\frac{\partial Loss}{\partial y^{4}_{n}}\frac{\partial y^{4}_{n}}{\partial x^{4}_{n}}\frac{\partial x^{4}_{n}}{\partial y^{1}_{n-1}}\frac{\partial y^{1}_{n-1}}{\partial x^{1}_{n-1}}\frac{\partial x^{1}_{n-1}}{\partial w^{11}_{n-2}}

3. relu 和 ResNet 各自解决了什么问题

由于我们可以从 梯度计算图 中很容易得到 权重更新的数学计算式。我们可以很清晰地思考 relurelu **函数和 ResNet 残差网络分别解决了什么问题。

3.1 relu 解决梯度消失问题

对于 relurelu 来说主要是使得 **函数 的导数为 11, 从 上面的 梯度计算图中来说是使得下面这一类的路径的导数为 11 ,即直接变为一条直线。
对于 ResNet 残差网络的思考——残差网络可以解决梯度消失的原因
即对于没有死掉的神经元来说 yn1xn1=1\frac{\partial y^{1}_{n}}{\partial x^{1}_{n}}=1,对于死掉的神经元(relu dead) yn1xn1=0\frac{\partial y^{1}_{n}}{\partial x^{1}_{n}}=0 直接从求和计算中去掉。

由此我们可以简化上面的 wn111w^{11}_{n-1}wn211w^{11}_{n-2} 的数学计算式。

Losswn111=Lossyn1xn1wn111\frac{\partial Loss}{\partial w^{11}_{n-1}}=\frac{\partial Loss}{\partial y^{1}_{n}}\frac{\partial x^{1}_{n}}{\partial w^{11}_{n-1}}

Losswn111=Lossyn1xn1yn11xn11wn211+Lossyn2xn2yn11xn11wn211+Lossyn3xn3yn11xn11wn211+Lossyn4xn4yn11xn11wn211\frac{\partial Loss}{\partial w^{11}_{n-1}}= \frac{\partial Loss}{\partial y^{1}_{n}}\frac{\partial x^{1}_{n}}{\partial y^{1}_{n-1}}\frac{\partial x^{1}_{n-1}}{\partial w^{11}_{n-2}} \\+\frac{\partial Loss}{\partial y^{2}_{n}}\frac{\partial x^{2}_{n}}{\partial y^{1}_{n-1}}\frac{\partial x^{1}_{n-1}}{\partial w^{11}_{n-2}} \\+\frac{\partial Loss}{\partial y^{3}_{n}}\frac{\partial x^{3}_{n}}{\partial y^{1}_{n-1}}\frac{\partial x^{1}_{n-1}}{\partial w^{11}_{n-2}} \\+\frac{\partial Loss}{\partial y^{4}_{n}}\frac{\partial x^{4}_{n}}{\partial y^{1}_{n-1}}\frac{\partial x^{1}_{n-1}}{\partial w^{11}_{n-2}}

同时使用上面我们提到的三叉路口的化简方法
我们可以得到如下非常优美的式子

Losswn111=Lossyn1yn11\frac{\partial Loss}{\partial w^{11}_{n-1}}=\frac{\partial Loss}{\partial y^{1}_{n}}y^{1}_{n-1}

Losswn111=Lossyn1wn111yn21+Lossyn2wn112yn21+Lossyn3wn113yn21+Lossyn4wn114yn21\frac{\partial Loss}{\partial w^{11}_{n-1}}= \frac{\partial Loss}{\partial y^{1}_{n}}w^{11}_{n-1}y^{1}_{n-2} \\+\frac{\partial Loss}{\partial y^{2}_{n}}w^{12}_{n-1}y^{1}_{n-2} \\+\frac{\partial Loss}{\partial y^{3}_{n}}w^{13}_{n-1}y^{1}_{n-2} \\+\frac{\partial Loss}{\partial y^{4}_{n}}w^{14}_{n-1}y^{1}_{n-2}

可以从上述化简中看到,之前讨论 relurelu **函数的博客文章说的不错,relurelu **函数利用自己导数为 11 的特性解决了 梯度消失 的问题。
但为什么留下来的这个简洁的式子仍然有梯度消失的问题呢?

3.2 ResNet 解决梯度消失问题

我们接着分析上面化简得到的这个优美的式子。相比于wn111w^{11}_{n-1}来说,我们可以看到 wn211w^{11}_{n-2} 的计算式子中,有很多最后一层的权重因子相当于加权求和了。这是因为在 梯度计算图 中经过了 如下的权重层导致有多条路径可以选择的结果。
对于 ResNet 残差网络的思考——残差网络可以解决梯度消失的原因
那我们在进一步如果我们计算 wn211w^{11}_{n-2} 反向传播的数学计算式,在梯度计算图它会经过两个权重层,带来的结果是会在前面乘上两个权重层的权重。这会使得梯度降低吗? 答案是 会的,而且层数越多越明显最后仍然会导致梯度消失的结果。

我们可以简单分析一下。在上面计算wn211w^{11}_{n-2}梯度的计算式中,如果权重 wn111w^{11}_{n-1}wn112w^{12}_{n-1}wn113w^{13}_{n-1}wn114w^{14}_{n-1} 求和值小于 1 会如何? 答案是在前面乘的权重次数越多最后得到的梯度越小。

今天太晚了,明天我再接着写关于梯度消失以及 ResNet 梯度计算图,如何解决这一问题的。