决策树剪枝(cart剪枝)的原理介绍

最近看《统计学习方法》,最后有一部分讲到cart树的剪枝策略,个人觉得书上讲得比较晦涩难懂,现在结合个人理解和大家分享下,如有不正,敬请谅解!

1.决策树剪枝
决策树是一种分类器,通过ID3,C4.5和CART等算法可以通过训练数据构建一个决策树。但是,算法生成的决策树非常详细并且庞大,每个属性都被详细地加以考虑,决策树的树叶节点所覆盖的训练样本都是“纯”的。因此用这个决策树来对训练样本进行分类的话,你会发现对于训练样本而言,这个树表现完好,误差率极低且能够正确得对训练样本集中的样本进行分类。训练样本中的错误数据也会被决策树学习,成为决策树的部分,但是对于测试数据的表现就没有想象的那么好,或者极差,这就是所谓的过拟合(Overfitting)问题。

决策树剪枝主要可以分为两种:预剪枝和后剪枝

预剪枝(Pre-Pruning)

在构造决策树的同时进行剪枝。所有决策树的构建方法,都是在无法进一步降低熵的情况下才会停止创建分支的过程,为了避免过拟合,可以设定一个阈值,熵减小的数量小于这个阈值,即使还可以继续降低熵,也停止继续创建分支。但是这种方法实际中的效果并不好,因为在实际中,面对不同问题,很难说有一个明确的阈值可以保证树模型足够好,当然在xgboost和lightGBM里面有一些参数例如min_child_weight也算是设置了分裂节点的权重值,像xgboost之类已经把过拟合写进损失函数中了,因此不需要有剪枝的过程。

后剪枝(Post-Pruning)

后剪枝的剪枝过程是删除一些子树,然后用其叶子节点代替,在剪枝过程中, 将一些子树删除而用叶节点代替,这个叶节点所标识的类别用这棵子树中大多数训练样本所属的类别来标识。

决策树构造完成后进行剪枝。剪枝的过程是对拥有同样父节点的一组节点进行检查,判断如果将其合并,熵的增加量是否小于某一阈值。如果确实小,则这一组节点可以合并一个节点,其中包含了所有可能的结果。后剪枝是目前最普遍的做法。

剪枝作为决策树后期处理的重要步骤,是必不可少的。没有剪枝,就是一个完全生长的决策树,是过拟合的,需要去掉一些不必要的节点以使得决策树模型更具有泛化能力。

2.CART剪枝
树上的流程基本是这样:
对于原始的CART树A0,先剪去一棵子树,生成子树A1,然后再从A1剪去一棵子树生成A2,直到最后剪到只剩一个根结点的子树An,于是得到了A0-AN一共n+1棵子树。
然后再用n+1棵子树预测独立的验证数据集,谁的误差最小就选谁,大致的思路就是这样。
然后需要怎么去生成这些树呢,每棵树根据什么取选择剪枝,书上我一直看不太懂是下面这个:
决策树剪枝(cart剪枝)的原理介绍
书里对g(t)的解释是:它表示剪枝后整体损失函数减少的程度。。。,然而我当时并不是很理解,这么说应该剪掉以g(t)最大结点为根的子树,因为g(t)最大,那么剪枝后整体损失
函数减少程度也最大。但书中的算法却说优先剪去g(t)最小的子树,这就不懂为啥要从小开始剪枝,后来上网查了之后才懂。


实际上这个g(t)表示剪枝的阈值,即对于某一结点a,当总体损失函数中的参数alpha = g(t)时,剪和不剪总体损失函数是一样的(这可以在书中(5.27)和(5.28)联立得到)。
这时如果alpha稍稍增大,那么不剪的整体损失函数就大于剪去的。即alpha大于g(t)该剪,剪了会使整体损失函数减小;alpha小于g(t)不该剪,剪了会使整体损失函数增大。
(请注意上文中的总体损失函数,对象可以是以a为根的子树,也可以是整个CART树,对a剪枝前后二者的总体损失函数增减是相同的。)对于同一棵树的结点,alpha都是一样的,
当alpha从0开始缓慢增大,总会有某棵子树该剪,其他子树不该剪的情况,即alpha超过了某个结点的g(t),但还没有超过其他结点的g(t)。这样随着alpha不断增大,不断地剪枝,就
得到了n+1棵子树,接下来只要用独立数据集测试这n+1棵子树,试试哪棵子树的误差最小就知道那棵是最好的方案了。

By the way,书上第二版的最后貌似迭代过程是错的,应该是回到步骤2, 因为,每次都需要逐步增加g(t),这样遍历所有的g(t)之后就一定会剪成一颗只有一个根节点的树,这个大家可以好好理解。

这里有张图,说明为什么要选择小的g(t)值,看看应该就能懂

决策树剪枝(cart剪枝)的原理介绍


为什么要选择最小的g(t)呢?以图中两个点为例,结点1和结点2,g(t)2大于g(t)1, 假设在所有结点中g(t)1最小,g(t)2最大,两种选择方法:当选择最大值g(t)2,即结点2进行剪枝,
但此时结点1的不修剪的误差大于修剪之后的误差,即如果不修剪的话,误差变大,依次类推,对其它所有的结点的g(t)都是如此,从而造成整体的累计误差更大。反之,如果
选择最小值g(t)1,即结点1进行剪枝,则其余结点不剪的误差要小于剪后的误差,不修剪为好,且整体的误差最小。从而以最小g(t)剪枝获得的子树是该alpha值下的最优子树!

这下该明白了吧,反正我是明白了。。。有不懂可以留言

最后附上cart剪枝的正确流程图:

决策树剪枝(cart剪枝)的原理介绍