[DNN] 尝试理解深度神经网络的Large-batch魔咒
[DNN] 尝试理解深度神经网络的Large-batch魔咒
1 天前
最近贵司的“一小时训练ImageNet”论文在国内外各种刷屏(https://research.fb.com/publications/imagenet1kin1h/),看了一下,确实非常实用主义的文章,介绍很多有用的trick,包括系统实现上的很多坑都覆盖到了。其中谈到加速训练的难点之一是:需要用到更大的mini-batch size,而这通常会降低准确率,所以他们通过linear-scaling learning rate解决了这个问题。看到这里我对于这个难点产生了疑问——batch size越大,不应该训练的方差越小,随机性越小,从而能够更准确地拟合数据集么?
从一个对深度学习接触不多的人(比如我)的角度,这点确实有点反直觉。当batch-size不断增大,直到跟数据集一样大的时候,SGD (Stochastic Gradient Descent)就变成了最朴素的GD,一次梯度更新会扫描一遍所有的数据来算梯度。看教科书和在CMU上Machine Learning的时候被灌输的理念都是:SGD相对于GD,或者小的batch相对于大的batch,有助于更快收敛,但是准确度会下降。为什么到了深度神经网络这里就反过来了呢?
我的第一猜想是:神经网络的函数空间非常non-convex,所以mini-batch越小就越容易不断跳出local minima,寻找更好的最小值。但是自己马上感觉这个猜想有很多漏洞,不能自圆其说,所以我去查证了一下其他人的分析——Facebook的论文原文有提及过这个问题,以及ICLR 2017上也有一篇论文针对这个问题分析了一下。有趣的事,两篇文章的观点并不相同,Facebook的论文还轻踩了对方一下说“根据我们跑的实验这事不是你们说的那样儿的”。由于ICLR 的论文先出来,我们先看看它怎么说:
ICLR 2017: ON LARGE-BATCH TRAINING FOR DEEP LEARNING: GENERALIZATION GAP AND SHARP MINIMA https://openreview.net/pdf?id=H1oyRlYgg
这篇文章主要研究了“为什么Large batch size会让错误率提高”的问题,提出了四个可能的猜想:
(i) LB methods over-fit the model;
(ii) LB methods are attracted to saddle points;
(iii) LB methods lack the explorative properties of SB methods and tend to zoom-in on the minimizer closest to the initial point;
(iv) SB and LB methods converge to qualitatively different minimizers with differing generalization properties.
然后通过实验,得出了支持(iii)和(iv)的证据。也就是说,主要是两点原因:
1) LB (Large-Batch) 方法探索性太差,容易在离起始点附近很近的地方停下来
2) LB和SB由于训练方式上的差异,最终会导致它们最终收敛的点具有一些数学属性的差异
#1 很好理解,跟我前面的猜想有点类似。这里着重谈谈#2 - 文章谈到,LB方法会收敛到Sharp-minimum,而SB方法会收敛到Flat-minimum。这两种minimum的差别如图所示:
在同样的Bias下,明显Flat的曲线比Sharp的曲线更加接近真实情况,所以Flat Minimum的generalization performance更好。
然后,基于这个假设,他们给出的解决方案是:先用SB方法训练几个epoch,让它先探索一下,找到一个比较Flat的区域,再用LB方法慢慢收敛到正确的地方。论文给出了performance vs. # of epoch trained with SB,但个人感觉不是很有说服力。。。
Facebook: Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour
再回到Facebook这篇文章,作者认为,LB之所以不work,不是因为上面那篇论文提到的泛化能力的问题,而主要是一个optimization issue(我的理解是优化过程/优化算法的问题)。文章没有给出理论分析, 而是直接给出了实验数据:首先,这篇论文是基于“Linear Scaling Learning Rate”来做的,简单来说,假如说原来batch size是256,learning rate是0.1;那么当把batch size设成8192的时候,learning rate就设成3.2 。batch size翻多少倍,learning rate就翻多少倍。然后,基于这个方法,论文作者发现,如果用LB方法,刚开始就用很大的learning rate的话,效果其实是很差的;但是,只要刚开始把LR设小点,后来逐步把LR提高到正常的大小,那么效果拔群,LB能够得到跟SB几乎一毛一样的training curve,以及基本相同的准确度。
基于这个观察,作者认为,LB不work的主要原因是
large minibatch sizes are challenged by optimization difficulties in early training
(至于为什么,这个跟Linear Scaling Learning Rate的assumption有关:简单来说,就是Linear Scaling Learning Rate这个trick是基于一定的assumption的,而这个assumption在网络权重急剧变化的时候——也就是刚开始训练的时候——是不成立的。所以,一开始就应用那么大的learning rate会出事。我解释的不是很清楚,具体可以去看原论文)
总结
上篇两篇论文各有千秋:ICLR那篇着重理论分析,用漂亮的实验验证了Sharp-minimum和Flat-minimum的区别,启发性非常大,但是给出的解决方案不是很令人信服;Facebook这篇直接从实战经验出发,实验和解释都比较令人信服,不过理论上相对弱些。
对于两者的Claim,其实不能说谁对谁错,因为两者的实验方法不一样;ICLR那篇没有应用Linear Scaling Learning Rate而是直接应用了ADAM来作为optimizer,得出的结果跟Facebook的肯定不能直接相比。如果ICLR那篇论文的作者可以使用Facebook的方法论重新跑实验的话,说不定得出的结论会有很大不同。甚至说,双方的结论其实不完全互斥,而是可以被统一成一个理论(比如我现在拍脑袋想的:刚开始训练的时候,Large-batch得出来的梯度不准确,所以如果设的learning rate太大,就更加容易陷入Sharp-minimum出不来,从而影响到后面的优化,之类之类的)。
「真诚赞赏,手留余香」
转载于:https://my.oschina.net/airship/blog/919864