深度学习中的正则化处理Normalization

Normalization

Batch Normaliation

批标准化处理,批:指一批数据,通常为mini-batch;标准化:0均值,1方差。

  • 可以用更大学习率,加速模型收敛
  • 可以不用精心设计权值初始化
  • 可以不用dropout或者较小的dropout
  • 可以不用L2或者较小的Weight decay
  • 可以不用LRN(local response normalization)

计算式

深度学习中的正则化处理Normalization

其中,normalize步骤中ϵ\epsilon为修正项,为了防止分母为零的情况出现。处理后x^\hat{x}即为0均值1方差,但BN算法仍未结束,最后还需要进行一步Affine Transform,即γxi^+β\gamma\hat{x_i}+\beta其中γ\gammaβ\beta称为scale与shift,这两个参数是可学习的,可通过反向传播改变。

_BatchNorm

pytorch中的Batch Normalization实现

  • num_features:一个样本特征数量(最重要)
  • eps:分母修正项
  • momentum:指数加权平均估计当前mean/var
  • affine:是否需要affine transform,默认为True
  • track_running_stats:是训练状态,还是测试状态

以下三个方法具体的实现都继承于_BatchNorm

  • nn.BatchNorm1d
  • nn.BatchNorm2d
  • nn.BatchNorm3d

主要属性:

  • running_mean:均值
  • running_var:方差
  • weight:affine transform中的gamma
  • bias: affine transform中的beta

均值和方差根据训练与测试的不同,拥有不同的计算式。
在训练时:不止考虑当前时刻,还会考虑之前的结果,计算式如下:
meanrunning=(1momentum)meanpre+momenturmmeantvarrunning=(1momentum)varpre+momentumvart mean_{running} = (1 - momentum) * mean_{pre} + momenturm * mean_t\\ var_{running} = (1 - momentum) * var_{pre} + momentum * var_t

除BN之外,还有LN,IN,GN等normaliazation方法,与BN的区别仅仅是均值和方差的计算方式不同。

Layer Normalization

起因:BN不适用于变长的网络,如RNN
思路:逐层计算均值和方差

  • 不再有meanrunningmean_{running}varrunningvar_{running}
  • gamma和beta为逐元素的

nn.LayerNorm

  • normalized_shape:该层特征形状
  • eps:分母修正项
  • elementwise_affine:是否需要affine
    transform

instance Normalization

起因:BN在图像生成(Image Generation)中不适用
思路:逐Instance(channel)计算均值和方差

nn.InstanceNorm

  • num_features:一个样本特征数量(最重要)
  • eps:分母修正项
  • momentum:指数加权平均估计当前mean/var
  • affine:是否需要affine transform
  • track_running_stats:是训练状态,还是测试状

Group Normalization

起因:小batch样本中,BN估计的值不准
思路:数据不够,通道来凑
应用场景:大模型(小batch size)任务

  • 不再有running_mean和running_var
  • gamma和beta为逐通道(channel)的

nn.GroupNorm

  • num_groups:分组数
  • num_channels:通道数(特征数)
  • eps:分母修正项
  • affine:是否需要affine transform

Dropout

注意:由于数据尺度变化,测试时,所有权重乘以1-p

nn.Dropout

  • p:被舍弃概率,失活概率,默认为0.5

注意:在Pytorch中为了测试更加方便,在训练时权重已经预先乘以11p\frac{1}{1-p},故如果使用Pytorch中的nn.Dropout,则测试时,不需要再乘上(1-p)