BN层的反向传播
参考链接:
BN层的参考链接1
---- 前言
BN层的的公式:输入进行均值和方差,然后归一化,接着有两个参数,分别是scale和shift,其实一定程度上可以让BN层选择归一化恢复多少。有时候可以刚刚好等于均值和方差,那么就抵消了。
这里就说两个参数的意义,有时候我们并不想归一化的进入,这时候重要两个参数等于方差和均值,那么就可以抵消了归一化。
----主体部分
前向传播:
根据上面的公式假设我们现在只有两个input然后output也只有两个的话,一个BN层的内部结构如下,改变输入数据的分布,让模型自己去学习输入数据的分布就是我们的BN层的核心:
所以要求BN层的反向传播一般我们就要知道损失函数对X1和X2的梯度,求出这两个基本上就可以梯度继续反向传播了,同时BN层也有两个需要更新的参数,这两个也需要计算损失函数对他们的梯度。所以要一共需要计算损失函数对输入的梯度,对两个参数的梯度。我们对输入的梯度计算是为了确保链式法则和反向传播得意继续,而对参数的梯度计算则是为了更新参数本身,这两者是不同的目的的。
首先我们对两个参数进行更新,所以我们需要计算损失函数对参数本身的梯度。也就是????和????。
对????的梯度就是损失函数对y的梯度:
对????梯度:损失函数对y的梯度与两个新的x的乘积:
接下来我们需要计算对两个新的x的梯度:因为我们后续的这样符合我们的链式法则和反向传播一步步的计算:
然后我们继续反向传播继续求梯度,一步步的接近我们的输入x,因为我们的x影响了我们的均值和方差,也影响了我们的新的两个x的生成,所以除了求损失函数对两个新的x的梯度我们还需要求损失函数对均值和方差的梯度。这两个的计算比较复杂要注意啦~~~~,因为均值和方差两个不仅相互影响,而且影响了两个新的x。所以计算梯度需要都考虑进去。
注意的是上面的,对均值的梯度后面部分是0 ,可以看看他的上面那一部分。到这里我们就求出来了均值和方差的梯度,还有两个新的x的梯度,这些都是被input影响的,那么现在就可以求损失函数对input的x的梯度了:
所以最终对每个xi的梯度是这样的: