剪枝

剪枝

本文主要介绍了一种压缩网络的方法:剪枝。剪枝就是去除网络中一些不重要的神经元,大大降低了计算量和权重数量,提高了网络运行效率。

01

决策树剪枝

决策树剪枝在周志华老师的《机器学习》一书中给了简洁易懂的介绍,此处主要引自书中。

剪枝技术最先被用于决策树,是决策树对付过拟合的重要办法。在决策树学习中,为了尽可能正确分类训练样本,结点划分过程将不断重复,导致结点过多。这时就可能因为决策树学的太好了,以至于把训练集自身的一些特点当作所有数据共有的属性而导致过拟合。

决策树剪枝包括预剪枝和后剪枝。预剪枝是在决策树生成过程中,对将要划分的结点进行评估,如果不能带来决策树泛化性能的提升,那么就不划分这个结点。后剪枝是生成了决策树之后进行的,从非叶结点开始考察,如果将该结点对应的子树替换为叶结点,能够提高精度,那么就用叶结点来替代这课子树。

如果仅仅有训练数据集,那么在生成决策树的过程中就会不断生成结点。因而为了进行剪枝,数据集分为两部分,一是训练集,另外一个验证集。通过训练集来生成可能的结点,然后再通过验证集来进行过滤。如果生成的结点在验证集上不能给出好的准确率,那就去除这个结点。但是预剪枝有欠拟合的风险。

后剪枝从生成的完全决策树T0低端开始,不断剪枝,形成一系列子树{T0, T1,…},然后通过交叉验证法得到最优的子树。在进行剪枝得到子树的时候,是通过如下损失函数来进行判定的:

剪枝

其中C(T)是任意树的损失函数,正则项是为了在树的大小和准确率之间做一个权衡。从该式子看出如果a=0,那么整体数是最优的,如果a趋近于无穷大,那么以单节点为树的是最优的。那么可以这样做:从较小的a开始,不断增大,得到一系列区间,这些区间对应着不同的子树。

具体来说,首先从子树T0,开始考察每个单节点t,然后计算以单节点为树的损失函数和以该结点为根的子树,如果二者相等,说明可以去除该结点对应的子树。用公式表示:

剪枝

对每个结点计算此数值,然后选择最小的,去除以t为根节点的子树,得到新的子树T1。再以T1为新的树,重复上述过程。得到一系列树,并且对应着一系列a区间:[ai, a(i+1)]。

第二步就是在验证集上考察每个子树的平方误差或者基尼指数,选择最优子树。

02

深度神经网络剪枝

剪枝

剪枝的目的是为了减小神经网络的参数数量,降低计算量。在实际应用中,一个大的神经网络因为应用场景的局限,存在一定的冗余。剪枝技术就是降低这种冗余的办法。剪枝的核心思想就是:在保证精确度没有较大损失的前提下,尽量减少神经元的连接。

剪枝过程为:首先在一个数据集上完成对一个神经网络的训练,接下来使用相似但是不同的数据集用于剪枝。针对训练好的神经网络,按照一定的标准计算需要剪枝的神经元,去除影响最小的神经元,然后在新的数据集上进行fine-tuning。不断进行这样的过程,一直到神经元被减少到要求达到的目标,而精度也在损失的范围内。

剪枝的标准有很多,最简单的是暴力剪枝。假设C{D|W}为数据集D,参数集W的代价函数,最小化以下L1-norm:

剪枝

这是一个组合问题,因为每个参数都有保留和去除两种情况,所以总共需要进行计算的次数达到2^|W|次。

Oracle剪枝技术是每次去除一个神经元,观察代价函数的变化,如果代价函数更小了就说明去除这个神经元是有利于精度提高的。那就去除这个神经元。这种方式的计算次数是W。计算量也很大。

剪枝

最小权重方法:考察每个权重的大小会消耗大量的计算,那么为什么不去计算一系列相关的权重大小呢?定义一个计算kernel权重L2-morm大小的量:

剪枝

如果这个l2-norm较小,那么就说明这个kernel对结果的影响较小,就可以舍弃。这种方法避免了像之前需要重新计算代价函数的过程,在fine-tuning中就可以完成,而不需要再增加额外的计算。

**函数方法:在一个神经网络中,**函数的作用是过滤一些不重要的量。因此通过**函数来判断权重的影响也是一种很好的方法。那么用平均值来进行判定:

剪枝

也可以通过**值的标准差:

剪枝

互信息:如果一个权重对结果的影响很重要,那么说明它和其它变量有很强的关系。那么就可以通过这种方式去判断。熵用来表示信息量的多少,通过熵的变化可以作为判定:

剪枝

其中IG表示互信息,而H为熵。H(x,y)就是联合熵。

Taylor展开:如果将代价函数进行taylor展开,忽略高阶项,可以降低通过代价函数最小化来剪枝的技术。假设hi为来自i参数的输出结果,如果其为0,表示这个参数可以被去除。那么去除后对代价函数的影响用公式表示:

剪枝

如果对代价函数的影响较小,那么说明这个参数是冗余的,就可以被优化掉。

THE END

剪枝在决策树中用于解决过拟合问题,同时也可以降低计算量。剪枝在深度学习中主要是为了减少神经网络中冗余,减少参数,加速硬件的计算。

剪枝

剪枝

剪枝