再探反向传播算法(推导)

之前也写过关于反向传播算法中几个公式的推导,最近总被人问到其中推导的细节,发现之前写的内容某在些地方很牵强,很突兀,没有一步一步紧跟逻辑(我也不准备修正,因为它也代表了一种思考方式)。这两天又重新回顾了一下反向传播算法,所有就再次来说说反向传播算法。这篇博文的目的在于要交代清楚为什么要引入反向传播算法,以及为什么它叫反向传播。

1.从前(正)向传播谈起

在谈反向传播算法之前,我们先来简单回顾一下正向传播(详细版戳此处)。假设有如下网络结构:
再探反向传播算法(推导)

其中:

L=神经网络总共包含的层数Sl=l层的神经元数目K=输出层的神经元数,亦即分类的数目wijl=ljl+1i

即对如上网络结构来说,L=3,s1=3,s2=2,s3=K=2ail表示第l层第i个神经元的**值,bl表示第l层的偏置。

则有如下正向传播过程:

z12=a11w111+a21w121+a31w131+b1z22=a11w211+a21w221+a31w231+b1[z12z22]=[w111w121w131w211w221w231]2×3×[a11a21a31]3×1+[b1b1]z2=a1w1+b1a2=f(z2)z3=a2w2+b2a3=f(z3)

所以可以得出正向传播过程几个公式:

(1)zil+1=a1lwi1l+a2lwi2l++aSllwiSll+bl(2)zl+1=alwl+bl(3)al=f(zl)

其中,f()表示**函数,如sigmoid函数。

现在我们已经知道了正向传播的过程,也就是说当我们训练得到参数w之后,就可以用正向传播通过网络来预测了。但是大家有没有想过,参数w是怎么训练得到的?那第一反应肯定是运用梯度下降算法。既然是用梯度下降算法来求解参数,那第一步当然就是求解梯度了。

2.求解梯度

为了方便阅读,在这个位置再插入一张上面同样的网络结结构图:

再探反向传播算法(推导)

此时,我们假设网络的目标函数为误差平方函数,且暂时不管正则化,同时只考虑一个样本即:

J=12(hw,b(x)y)2

且此处hw,b(x)=a3
由此,我们可以发现:如果Jw111求导,则J是关于a3的函数,a3是关于z3的函数,z3是关于a2的函数,a2是关于z2的函数,w111是关于z2的函数。

为了更加清晰下面的求导过程,我们先来举两个例子,看看链式求导的过程(如果熟悉链式求导规则,请直接忽略)。


例1:
假设有如下函数:

f=sin(t),t=x2,x=5wfw=fttxxw=cos(t)2x5=cos(x2)2x5=cos(25w2)10w5=50wcos(25w2)

作为验证,我们直接将t,x带入f然后求导:

f=sin(x2)=sin(25w2)fw=cos(25w2)50w=50wcos(25w2)

例2:
我们再来看一个抽象的,没有表达式得链式求导,假设有如下函数表达式:

f=g(t),t=ϕ(x+y),x=h(w),y=μ(w)

则我们可以画出如下关系图:
再探反向传播算法(推导)
即,tf的函数,yx都是t的函数,w分别又都是yx的函数,也就是说我们有两条路径可以到达w,所以
fw=fttyyw+fttxxw=ft(tyyw+txxw)


所以有:

Jw111=Ja13a13z13z13a12a2z12z12w111+Ja23a23z23z23a12a2z12z12w111Jw121=Ja13a13z13z13a12a2z12z12w121+Ja23a23z23z23a12a2z12z12w121Jw222=Ja23a23z23z23w222

我们可以发现,当J对第2层的参数求导还相对不麻烦,但当J对第1层的参数求导的时候就做了很多重复的计算;并且这还是网络相对简单的时候,要是网络相对复杂一点,这个过程简直就是难以下手。这也是为什么神经网络在一段时间发展缓慢的原因,就是因为没有一种高效的计算梯度的方式。

3.一种高效的梯度求解办法

Jw111=(Ja13a13z13z13a12a2z12)z12w111+(Ja23a23z23z23a12a2z12)z12w111

从上面的求导公式可以看出,不管你是从哪一条路径过来,在对w111求导之前都会先到达z12,即先对z12求导之后,才会有z12w111。也就是说,我不管你是经过什么样的路径,在对连接第l层第j个神经元与第l+1i个神经元的参数wijl求导之前,肯定会先对zil+1求导。因此,对任意参数的求导过程,可以改写为:

(4)Jwijl=Jzil+1zil+1wijl=Jzil+1ajl

例如:

Jw111=Jz11+1z11+1w111=Jz12z12w111

所以,现在的问题变成了如何求解红色部分了,即:

Jzil+1=???

从网络结构图可以,J对任意zil求导,求导路径必定会经过第l+1层的所有神经元,于是有:

Jzil=Jz1l+1z1l+1zil+Jz2l+1z2l+1zil++JzSl+1l+1zSl+1l+1zil=k=1Sl+1Jzkl+1zkl+1zil=k=1Sl+1Jzkl+1zil(a1lwk1l+a2lwk2l++aSllwkSll+bl)1=k=1Sl+1Jzkl+1zilj=1Slajlwkjl=k=1Sl+1Jzkl+1zilj=1Slf(zjl)wkjl(5)=k=1Sl+1Jzkl+1f(zil)wkil

于是我们得到:

(6)Jzil=k=1Sl+1Jzkl+1f(zil)wkil

因此

Jzil+1=k=1Sl+2Jzkl+2f(zil+1)wkil+1

为了便于书写和观察规律,我们引入一个中间变量δil=Jzil,则(5)得:

(7)δil=Jzil=k=1Sl+1δkl+1f(zil)wkil(l<=L1)

注:之所以要l<=L1,是因为由(5)得推导过程可知,l最大只能取到L1,第L层后面没有网络层了。

所以:

δiL=JziL=ziL[12k=1SL(hk(x)yk)2]=ziL[12k=1SL(f(zkL)yk)2]=[f(ziL)yi]f(ziL)(8)=[aiLyi]f(ziL)

同时将(7)带入(4)可知:

(9)Jwijl=δil+1ajl

通过上面的所有推导,我们可以得到如下3个公式:

Jwijl=δil+1ajlδil=Jzil=k=1Sl+1δkl+1f(zil)wkil(0<lL1)δiL=[aiLyi]f(ziL)

且经过适量化后为:

(10)Jwl=δl+1(al)T(11)δl=(wl)Tδl+1f(zl)(12)δL=[aLy]f(zL)

符号表示矩阵乘法;符号表示两个矩阵相同位置的元素对应相乘

由(10)(11)(12)分析可知,欲求Jwl的导数,必先知道δl+1;而欲知δl+1,必先求δl+2,以此类推……
由此可知对于整个求导过程,一定是先求δL,再求δL1,一直到δ2

为了方便阅读,在这个位置再插入一张上面同样的网络结结构图:

再探反向传播算法(推导)

对于这样一个网络结构,整个求导过程(不含bl)如下:

Step1:δ3=[a3y]f(z3)Step2:Jw2=δ3(a2)TStep3:δ2=(w2)Tδ3f(z2)Step4:Jw1=δ2(a1)T

于是我们终于发现了这么一个不争的事实:
1.最先求解出导数的参数一定位于第L1层上(如此处的w2);
2.要想求解第l层参数的导数,一定会用到第l+1层上的中间变量δl+1(如此处求解w1的导数,用到了δ2);
3.整个过程是从后往前的;

所以,该过程被形象的称为反向(后向)传播算法。
另:δl被称为第l层的“残差”

一个重要的结论:
反向传播算法是用来求解梯度的!

反向传播算法是用来求解梯度的!

反向传播算法是用来求解梯度的!

重要的话说三遍,因为不少人总是把梯度下降和反向传播两个搞得稀里糊涂的。

4.总结

通过举例对平方误差目标函数反向传播算算法公式的推导,我们可以总结出更为一般的情况,即:

(13)Jwl=δl+1(al)T(14)δl=(wl)Tδl+1f(zl)(15)δiL=JziL=JaiLaiLziL=JaiLf(ziL)ziL=JaiLf(ziL)(16)Jbl=δl+1

我们可以看到,仅仅只有公式(15)才依赖于不同的目标函数;比如在交叉熵中δiL=aLy推导戳此处.

关于反向传播算法的推导基本上可以告一段落了,下一篇我们将通过一个例子用python来实现,这样就会更清楚了 。