CBN:Cross-Iteration Batch Normalization

论文:https://arxiv.org/abs/2002.05712
代码:https://github.com/Howal/Cross-iterationBatchNorm
https://github.com/Howal/Cross-iterationBatchNorm/blob/master/mmdet/models/utils/CBN.py

1 简介

1.1 BN

BN(Batch Normalization)在CNN的发展过程中起到了重要的作用,其主要是解决internel covariate shift,也就是在不同轮的迭代过程中,因为网络权重持续变化造成的中间层**值的分布变化问题。BN通过在训练过程中减去mini-batch的均值和方差对**值进行归一化处理以减少分布的变化,同时引入了可学习的变换系数γ\gammaβ\beta进行表示能力的补偿。使用BN可以使用较大的学习率进行网络训练,也不需要过于关注权重的初始值,并且训练过程中也有一定的防止过拟合的能力。

BN的计算过程可以表示为:

θt\theta_txt,i(θt)x_{t,i}(\theta_t)表示网络的权重值和第t个mini-batch中第i个样本的在某一层的**值。
ut(θt)u_t(\theta_t)σt(θt)\sigma_t(\theta_t)表示当前mini-batch中所有样本的均值和方差。
CBN:Cross-Iteration Batch Normalization
CBN:Cross-Iteration Batch Normalizationvt(θt)=1mi=1mxt,i(θt)2v_t(\theta_t) = \frac{1}{m}\sum_{i=1}^mx_{t,i}(\theta_t)^2,m表示mini-batch中的样本个数。归一化后的**值x^t,i(θt)\hat x_{t,i}(\theta_t)再进行下面所示的缩放操作:

CBN:Cross-Iteration Batch Normalization
γ\gammaβ\beta是可学习的参数。对于全连接层,是每一个神经元一组参数,对卷积层,是每一个输出channel一组参数,参数量等于当前层的输出channel数。

BN成功使用的前提是假设训练过程中各batchsize的训练样本和总体样本集符合相同的分布,BN的缺点是batchsize较小时,基于batchsize的样本估计出的均值和方差和总体样本集的分布就相差甚远,因此也就无法的改善模型的训练效果。在基于ImageNet数据集训练分类模型时,最常用的batchsize大小为32,但是在其他输入图像更大的应用场景下,如目标检测和语义分割任务,因为GPU显存的限制,batchsize往往很小,这种情况下BN的作用就非常有限了。

1.2 规范化

在CNN中,规范化可用于三个地方,分别为:输入数据、隐藏层的**值和网络参数。

1.2.1 输入数据的规范化

输入数据规范化是最常用的规范化手段,其好处可以参考:https://blog.****.net/cdknight_happy/article/details/106453319

1.2.2 隐藏层**值规范化

针对BN的缺点,后续出现了LN、IN、GN和BIN。LN是针对单个channel进行规范化,IN是针对单个样本的单个channel进行规范化,GN是针对单个样本的单个channel组进行规范化。CGBN和SyncBN是针对跨GPU的数据进行规范化,目的是增大有效的batchsize的大小。kalman规范化是对网络的某一层及其前面层的**值应用kalman滤波进行规范化。

1.2.3 网络权重规范化

Weight Normalization(WN),Wight Standardization(WS),Switch Normalization(SN),Sparse Switchable Normalization(SSN)

2 CBN

CBN的目的也是为了解决batchsize小时BN效果不好的问题,思路是跨越不同轮的迭代进行均值和方差的整合。但是如果直接对不同轮迭代产生的**值的均值和方差进行整合效果不好,因为每一轮的迭代输出都是基于当时的网络权重计算得到的。

因为网络训练过程中,网络权重在连续迭代过程中是在平滑变化的,所以可以通过泰勒多项式使用utτ(θtτ)u_{t - \tau}(\theta_{t - \tau})vtτ(θtτ)v_{t - \tau}(\theta_{t - \tau})近似utτ(θt)u_{t - \tau}(\theta_{t})vtτ(θt)v_{t - \tau}(\theta_{t })
CBN:Cross-Iteration Batch Normalization
O(θtθtτ2)O(||\theta_t - \theta_{t - \tau}||^2)表示泰勒展开式中的高阶项,因为(θtθtτ)(\theta_t - \theta_{t - \tau})值很小,所以高阶项可以被忽略。

式(5)和(6)中,utτ(θtτ)/θtτ\partial u_{t - \tau}(\theta_{t - \tau})/\partial \theta_{t - \tau}vtτ(θtτ)/θtτ\partial v_{t - \tau}(\theta_{t - \tau})/\partial \theta_{t - \tau}难以使用很小的计算量进行精确计算,这是因为utτl(θtτ)u^l_{t - \tau}(\theta_{t - \tau})vtτl(θtτ)v^l_{t - \tau}(\theta_{t - \tau})与比ll小的层都有关系,比如θtτl1\theta^{l-1}_{t - \tau}有变化,utτlu^l_{t - \tau}也会跟着改变。所以对与rlr \leq lutτl(θtτ)/θtτr0\partial u^l_{t - \tau}(\theta_{t - \tau})/\partial \theta^r_{t - \tau} \neq 0vtτl(θtτ)/θtτr0\partial v^l_{t - \tau}(\theta_{t - \tau})/\partial \theta^r_{t - \tau} \neq 0。但是,当r<lr < l时,utτl(θtτ)/θtτr\partial u^l_{t - \tau}(\theta_{t - \tau})/\partial \theta^r_{t - \tau}vtτl(θtτ)/θtτr\partial v^l_{t - \tau}(\theta_{t - \tau})/\partial \theta^r_{t - \tau}减小的很快,这也是BN中网络的前面层比后面层的internal
covariate shift要小的原因。所以,作者再次进行了近似,忽略一次迭代中当前层均值相对于前面层参数的梯度,得到:
CBN:Cross-Iteration Batch Normalization

上面计算出了单次迭代中各层的均值和方差,那么CBN是对k次迭代均值和方差再进行平均,CBN的均值和方差为:
CBN:Cross-Iteration Batch NormalizationCBN的规范化即:

CBN:Cross-Iteration Batch Normalization

CBN中,因为是使用k轮迭代的联合均值和方差,因此训练过程中的有效batchsize值相比于普通的BN被增大了k倍。

将BN换成CBN,计算量和内存占用的增加都很少,但有效batchsize值则进行了很大程度的扩充。对比见下表:
CBN:Cross-Iteration Batch Normalization
CBN和其他规范化手段是在不同的维度进行操作的,因此CBN可以和其他规范化手段结合使用。

超参数k
CBN中引入了一个新的超参数k — 迭代的轮数。k越大,有效的batchsize也越大,但是k太大时,因为迭代过程中参数的改变太大会造成参与求均值的样本的变化太大,会起到负面的效果。作者实验发现,设置k=8在大部分实验中可以取得比较好的效果。网络训练初期,权重改变的很快,应该使用较小的k值。所以,作者提出了一个TburninT_{burn-in}表示热身的epoch的次数,在图像分类实验中,设置TburninT_{burn-in}为25,目标检测实验中TburninT_{burn-in}为3,在TburninT_{burn-in}期间,设置k=1,CBN退化为BN。TburninT_{burn-in}结束后,再设置k=8应用CBN进行训练。

CBN:Cross-Iteration Batch Normalization

CBN:Cross-Iteration Batch Normalization
CBN:Cross-Iteration Batch Normalization

3 实验

CBN:Cross-Iteration Batch Normalization

上图是batch=32时的分类结果,可以看出对于大的batchsize,CBN和BN一样好。

CBN:Cross-Iteration Batch NormalizationCBN:Cross-Iteration Batch Normalization

随着batchsize的减小,BN的效果在快速衰减,但CBN仍然保持很好的效果。

CBN:Cross-Iteration Batch Normalization
建议k的取值为:
CBN:Cross-Iteration Batch Normalization

CBN:Cross-Iteration Batch Normalization