机器学习笔记(6)——C4.5决策树中的剪枝处理和Python实现

1. 为什么要剪枝

还记得决策树的构造过程吗?为了尽可能正确分类训练样本,节点的划分过程会不断重复直到不能再分,这样就可能对训练样本学习的“太好”了,把训练样本的一些特点当做所有数据都具有的一般性质,cong从而导致过拟合。这时就可以通过剪枝处理去掉yi一些分支来降低过拟合的风险。

剪枝的基本策略有“预剪枝”(prepruning)和“后剪枝”(post-pruning):

预剪枝是在决策树的生成过程中,对每个结点划分前先做评估,如果划分不能提升决策树的泛化性能,就停止划分并将此节点记为叶节点;

后剪枝是在决策树构造完成后,自底向上对非叶节点进行评估,如果将其换成叶节点能提升泛化性能,则将该子树换成叶节点。

那么怎么判断泛化性能是否提升呢?这时需要将数据集分为训练集和验证集,利用训练集构造决策树,利用验证集来评估剪枝前后的验证集精度(即正确分类的比例)。

下面我们把之前的西瓜数据集划分为训练集和验证集,之后在分别详细演示预剪枝和后剪枝的处理过程。

机器学习笔记(6)——C4.5决策树中的剪枝处理和Python实现

首先利用训练集数据,构造一个未做剪枝处理的决策树,以便于与剪枝后的决策树做对比。

注意:这里构造的决策树与《机器学习》中的不一样,因为色泽、根蒂、脐部三个属性的信息增益是相等的,都可以作为最优划分属性。

机器学习笔记(6)——C4.5决策树中的剪枝处理和Python实现

2. 预剪枝

我们先学习预剪枝的过程:

(1)根据信息增益准则,选取“色泽”作为根节点进行划分,会产生3个分支(青绿、乌黑、浅白)。

机器学习笔记(6)——C4.5决策树中的剪枝处理和Python实现

对根节点“色泽”,若不划分,该节点被标记为叶节点,训练集中正负样本数相等,我们将其标记为“是”好瓜(当样本最多的类不唯一时,可任选其中一类,我们默认都选正类)。那么训练集的7个样本中,3个正样本被正确分类,验证集精度为3/7*100%=42.9%。

对根节点“色泽”划分后,产生图中的3个分支,训练集中的7个样本中,编号为{8,11,12,4}的4个样本被正确分类,验证集精度为4/7*100%=57.1%。

于是节点“色泽”应该进行划分。

(2)再看“色泽”为乌黑这个分支,如果对其进行划分,选择“根蒂”作为划分属性

机器学习笔记(6)——C4.5决策树中的剪枝处理和Python实现

对“根蒂”这个分支节点,如果不划分,验证集精度为57.1%。如果划分,进入此分支的两个样例{8,9},编号为8的样例分类正确,编号为9的样例分类错误,所以对整棵树来说编号为{4,8,11,12}的4个样本分类正确,验证集精度仍为57.1%。

按预剪枝的策略,验证集精度没有提升的话,不再划分。

(3)“色泽”为浅白的分支只有一个类别,无法再划分。再评估“色泽”为青绿的分支,如果对其进行划分,选择“敲声”作为划分sh属性,产生3个分支。

机器学习笔记(6)——C4.5决策树中的剪枝处理和Python实现

对“敲声”这个分支节点,如果不划分,验证集精度为57.1%。如果划分,进入此分支的两个样例{4,13},编号为4的样例分类错误,编号为13的样例也分类错误,所以对整棵树来说编号为{8,11,12}的4个样本分类正确,验证集精度仍为42.9%。

按预剪枝的策略,验证集精度没有提升的话,不再划分。

因此,通过预剪枝处理生成的树只有一个根节点,这种树也称“决策树桩”。

机器学习笔记(6)——C4.5决策树中的剪枝处理和Python实现

优缺点分析:预剪枝使得决策树的很多分支没有展开,可以降低过拟合的风险,减少决策树的训练时间和测试时间。但是,尽管有些分支的划分不能提升泛化性能,但是后续划分可能使性能显著提高,由于预剪枝没有展开这些分支,带来了欠拟合的风险。

3. 后剪枝

后剪枝是在决策树构造完成后,自底向上对非叶节点进行评估,为了方便分析,我们对树中的非叶子节点进行编号,然后依次评估其是否需要剪枝。

对于完整的决策树,在剪枝前,编号为{11,12}的两个样本被正确分类,因此其验证集精度为2/7*100%=28.6%。

(1)第一步先考察编号为4的结点,如果剪掉该分支,该结点应被标记为“是”。进入该分支的验证集样本有{8,9},样本8被正确分类,对整个验证集,编号为{8,11,12}的样本正确分类,因此验证集精度提升为42.9%,决定剪掉该分支。

机器学习笔记(6)——C4.5决策树中的剪枝处理和Python实现

(2)再来考察编号为3的结点,如果剪掉该分支,该结点应标记为“是”,进入该分支的样本有{4,13},其中样本4被正确分类,对整个验证集,编号为{4,8,11,12}的样本正确分类,因此验证集精度提升为57.1%,决定剪掉该分支。

机器学习笔记(6)——C4.5决策树中的剪枝处理和Python实现

(3)再看编号为2的结点,如果剪掉该分支,该结点应标记为“是”,进入该分支的样本有{8,9},其中样本8被正确分类,样本9被错误分类,对整个验证集,编号为{4,8,11,12}的样本正确分类,验证集精度仍为57.1%,没有提升,因此不做剪枝。

机器学习笔记(6)——C4.5决策树中的剪枝处理和Python实现

(4)对编号为1的结点,如果对其剪枝,其验证集精度为42.9%(同预剪枝的第一步),因此也不剪枝。

后剪枝得到的决策树就是第(3)步的样子。

优缺点分析:后剪枝通常比预剪枝保留更多的分支,欠拟合风险小。但是后剪枝shi是在决策树构造完成后进行的,其训练时间的开销会大于预剪枝。

4. 后剪枝的Python实现

由于后剪枝的泛化能力高于预剪枝,这里只对后剪枝编程。为了方便评估,上述过程并没有包含连续属性,但是C4.5决策树是可以处理连续sh属性的,因此我们在编程中把连续属性也一并考虑进去。

def postPruningTree(inputTree, dataSet, data_test, labels, labelProperties):
    """ 
    type: (dict, list, list, list, list) -> dict
    inputTree: 已构造的树
    dataSet: 训练集
    data_test: 验证集
    labels: 属性标签
    labelProperties: 属性类别
    """
    firstStr = inputTree.keys()[0]
    secondDict = inputTree[firstStr]
    classList = [example[-1] for example in dataSet]
    featkey = copy.deepcopy(firstStr)
    if '<' in firstStr:  # 对连续的特征值,使用正则表达式获得特征标签和value
        featkey = re.compile("(.+<)").search(firstStr).group()[:-1]
        featvalue = float(re.compile("(<.+)").search(firstStr).group()[1:])
    labelIndex = labels.index(featkey)
    temp_labels = copy.deepcopy(labels)
    temp_labelProperties = copy.deepcopy(labelProperties)
    if labelProperties[labelIndex] == 0:  # 离散特征
        del (labels[labelIndex])
        del (labelProperties[labelIndex])
    for key in secondDict.keys():  # 对每个分支
        if type(secondDict[key]).__name__ == 'dict':  # 如果不是叶子节点
            if temp_labelProperties[labelIndex] == 0:  # 离散的
                subDataSet = splitDataSet(dataSet, labelIndex, key)
                subDataTest = splitDataSet(data_test, labelIndex, key)
            else:
                if key == 'yes':
                    subDataSet = splitDataSet_c(dataSet, labelIndex, featvalue,
                                               'L')
                    subDataTest = splitDataSet_c(data_test, labelIndex,
                                                featvalue, 'L')
                else:
                    subDataSet = splitDataSet_c(dataSet, labelIndex, featvalue,
                                               'R')
                    subDataTest = splitDataSet_c(data_test, labelIndex,
                                                featvalue, 'R')
            inputTree[firstStr][key] = postPruningTree(secondDict[key],
                                                       subDataSet, subDataTest,
                                                       copy.deepcopy(labels),
                                                       copy.deepcopy(
                                                           labelProperties))
    if testing(inputTree, data_test, temp_labels,
               temp_labelProperties) <= testingMajor(majorityCnt(classList),
                                                     data_test):
        return inputTree
    return majorityCnt(classList)

运行程序绘制出剪枝后的决策树,与上面人工绘制的一致。

机器学习笔记(6)——C4.5决策树中的剪枝处理和Python实现


参考:

周志华《机器学习》