决策树 机器学习实战
决策树
本篇教程将结合西瓜书来介绍。
一句话搞懂决策树是干啥的,还记得曾经风靡一时的一类游戏,“测测是你什么性格or什么类型的人”等测试类游戏,我们会有很多道选择题,比如每道题选择不同选项会跳到不同的问题上去,最终得出一个结果。这就是决策树。
一般一棵树包括根节点,叶子节点,中间节点和分枝。就从上图就能看出来。我们的任务就是构建这样的一棵树。
那么首先问题来了, 我这个根节点是怎么构建的。从感觉上来讲,你这个根节点应该是最能把数据做一个划分的分类。比如你对顾客的划分,最重要的是看有没有购买能力,如果购买能力很低的话,你跟本就不会再往后考虑了,所以我们的目的就是按照这样的思路进行划分。这有一个很专业的名称,叫“信息增益”。也就是越能对数据集进行划分的特征,该特征的信息增益越大。莫得问题吧。
为了能够量化这个信息增益,先提一个概念叫信息增益,也叫熵。熵这个词我是在高中化学中接触到的,它表示分子的混乱程度,在这里就是表示信息的混乱程度,熵越大,那么这些数据的混乱程度就越大。
比如现在有一坨信息,17条关于西瓜好不好的数据。8条为好瓜的信息,9条为不好瓜信息。那么这坨信息的信息熵就等于
对于这个公式,有个感性的想法,下图是y=log2x的函数。那么如果样本比例的越少,pi越小,则logpi就越大,反应出看就是混乱程度越大。但同时会乘所占比例。所以这样会推理出,一份信息集,有一个分类占的很多,则混乱程度小,信息熵就小,所以按照公式算出来的结果也是这样的。而如果信息集的数据分散在每个类别中,则混乱程度就会打,信息熵大,公式的结果也是同理。
import numpy as np
import math
import matplotlib
import matplotlib.pyplot as plt
x=np.arange(0.01,3,0.01)
y1 = [math.log2(a) for a in x]
plt.plot(x, y1, linewidth=2, color='#007500', label='log1.5(x)')
plt.plot(0.1, math.log2(0.1),'ro')
plt.plot(0.9, math.log2(0.9), 'ro')
plt.show()
我们再回到上述的信息增益上来。我为了选出分类能力最强的特征,那么我们就需要挨个计算按照每个特征划分之后,信息熵的大小。比如一坨信息集D的信息熵Ent(D)=0.9,按照类型1划分之后,信息熵Ent(D1)=0.7,按照类型2划分后信息熵Ent(D2)=0.5,类型3Ent(D3)=0.3。分别让Ent(D)减去这三个结果,发现减去类型3之后的值是最大的,也就是类型三,能够把熵减少到最小,所以类型三的信息增益最大。因而我们应该把类型三作为根节点进行分类。然后我们将信息三的分类结果再进行划分,再计算各种类型的信息熵,再求信息增益,再划分…这样就可以演化为一个递归程序。但是到此并没有考虑完全,下面我们通过程序实现上述过程,在实际代码中我们会发现欠缺的内容。
首先,如下两段函数的意思分别是求信息熵和创建一个数据集。之后通过一个程序,我们先对创建的数据集计算信息熵,得到0.97。当我们把这个数据集其中一个元素改变,效果是使得信息集更加混乱了,那么我们看到求得的信息熵变大了。变成1.37.
from math import log
def calcShannonEnt (dataSet):
numEntries = len(dataSet)
labelConuts = {}
for featVec in dataSet:
currentLabel = featVec[-1]
if currentLabel not in labelConuts.keys():
labelConuts[currentLabel]=0
labelConuts[currentLabel]+=1
shannonEnt =0.0
for key in labelConuts:
prob = float(labelConuts[key])/numEntries
shannonEnt -= prob*log(prob,2)
return shannonEnt
def createDataSet ():
dataSet = [[1,1,'yes'],
[1,1,'yes'],
[1,0,'no'],
[0,1,'no'],
[0,1,'no']
]
labels=['no surfacing','flippers']
return dataSet,labels
myDat,labels = createDataSet()
print(calcShannonEnt(myDat))
myDat[0][-1]='maybe'
print(calcShannonEnt(myDat))
0.9709505944546686
1.3709505944546687
这一步是划分数据集,dataset就是我们的等待划分的数据集,axis是换分的那一列,value表示这列的数值。
def splitDatSet(dataSet,axis,value):
retDataSet =[]
for featVec in dataSet:
if featVec[axis] == value:
reduceFeatVec = featVec[:axis]
reduceFeatVec.extend(featVec[axis+1:])
retDataSet.append(reduceFeatVec)
return retDataSet
splitDatSet(myDat,0,1)
[[1, 'maybe'], [1, 'yes'], [0, 'no']]
这里,splitDatSet(myDat,0,1),0表示第一列,1表示数据为1的值。然后我们把第一列刨去,剩下的橘黄色列的信息。(因为我上一步把第一个数据的yes环城路maybe,所以输出结果是maybe)
选择最好的数据集划分方式,这个意思也就是计算信息增益最大的。baseEntrpoy 是数据集的信息增益。
def chooseBestFeatureToSplit(dataSet):
numFeatures = len(dataSet[0])-1
baseEntrpoy = calcShannonEnt(dataSet)
bestInfoGain = 0.0
bestFeature = -1
for i in range(numFeatures):
featList = [example[i] for example in dataSet]
uniqueVals = set(featList) #这时看一个特征都有哪些分类,比如颜色特征,有红橙黄绿……
newEntropy = 0.0
for value in uniqueVals :#这个循环就是算出该特征的信息熵
subDataSet = splitDatSet(dataSet,i,value)
pro = len(subDataSet)/float(len(dataSet))
newEntropy += pro*calcShannonEnt(subDataSet)
infoGain = baseEntrpoy-newEntropy#用基础的减去上面算出的某特征的信息熵,求出信息增益
if (infoGain>bestInfoGain):#求信息增益最大值
bestInfoGain=infoGain
bestFeature=i
return bestFeature
chooseBestFeatureToSplit(myDat)
0
返回的这个‘0’,是第0个特征。
下面,majorityCnt这个函数先不说是啥意思。先看下面的createTree
import operator
def majorityCnt (classList):
classCount={}
for vote in classList:
if vote not in classCount.keys(): classCount[vote] = 0
classCount[vote] +=1
sortedClassCount = sorted(classCount.items(),key=operator.itemgetter(1),reverse=True)
return sortedClassCount[0][0]
我们就要创建一棵树了。我们可以先从第11行查看我们先要创建一棵树,那么先要求出最好的特征,通过chooseBestFeatureToSplit函数,我们找到了最好的特征之后,把它加入到树中,之后我们要把这个特征从数据中去掉。所以在for循环中的递归步骤中,createTree函数的dataSet抛去了最佳特征。然后再作为一个数据集继续划分。
那么这样划分下去,终将会得有个终止的叶子节点吧。那么就会出现两种情况,一种情况就是子类全是一个结果。还是看下图,购买能力高的属性,都是vip所以这些数据就被化为一个叶子节点,也就是函数中的第一个if中的内容。
第二种情况就是,我所有的特征都划分完了,还有没有划分进去的数据,那么这些数据就是统计每个结果的频率,最后排序,返回出现最多的分类名称。也就是上面majorityCnt 函数做的工作。
def createTree(dataSet,labels):
classList = [example[-1] for example in dataSet]
print(classList)
myTree={}
if classList.count(classList[0]) == len(classList):#子类都是一个结果
print(myTree)
return classList[0]
if len(dataSet[0]) ==1:
print(myTree)
return majorityCnt(classList)
bestFeat = chooseBestFeatureToSplit(dataSet)
bestFeatLabel = labels[bestFeat]
myTree = {bestFeatLabel:{}}
del(labels[bestFeat])
featValues = [example[bestFeat] for example in dataSet]
uniqueVals = set(featValues)
for value in uniqueVals:
subLables = labels[:]
myTree[bestFeatLabel][value] = createTree(splitDatSet(dataSet,bestFeat,value),subLables)
print(myTree)
return myTree
myDat,labels = createDataSet()
myTree = createTree(myDat,labels)
print(myTree)
['yes', 'yes', 'no', 'no', 'no']
['no', 'no']
{}
['yes', 'yes', 'no']
['no']
{}
['yes', 'yes']
{}
{'flippers': {0: 'no', 1: 'yes'}}
{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
创建好一个树之后,也就是训练好训练集之后,我们就要开始测试进行分类了。如下函数classify
def classify(inputTree,featLabels,testVec):
firstStr = list(inputTree.keys())[0] #第一个特征名称;‘no surfacing’
secondDict = inputTree[firstStr] #子树。secondDict 的内容就是抛去根节点,剩下的边
featIndex = featLabels.index(firstStr) #选取第一个特征的index值
for key in secondDict.keys():# 把子树的属性key 罗列出来
if (testVec[featIndex]== key): #测试集的数据和子树的属性对比,相等的化进入下一步
if type(secondDict[key]).__name__=='dict': #确定之后,如果子树下面还是树,就继续递归
classLabel=classify(secondDict[key],featLabels,testVec)
else:classLabel=secondDict[key]#不是树是节点的话,得到分类结果
return classLabel
mydat,labels=createDataSet()
result = classify(myTree,labels,[1,0])
print(result)
result = classify(myTree,labels,[1,1])
print(result)
no
yes
下面就是把我们的训练结果进行存储。
def storeTree(inputTree,filename):
import pickle
fw=open(filename,'wb')
pickle.dump(inputTree,fw)
fw.close()
def grabTree(filename):
import pickle
fr = open(filename,'rb')
return pickle.load(fr)
storeTree(myTree,'res.txt')
grabTree('res.txt')
{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
至此,一颗简单的决策树的代码就到此为止了。为什么说是简单,因为还有很多内容没有提到,因为首先从算法上。我们介绍的信息增益的方法属于ID3算法,采用的信息增益度量存在一个缺点,它一般会优先选择有较多属性值的Feature,因为属性值多的Feature会有相对较大的信息增益?(信息增益反映的给定一个条件以后不确定性减少的程度,必然是分得越细的数据集确定性更高,也就是条件熵越小,信息增益越大)后面的内容中,将会介绍改进的方法,C4.5和CART算法。同时从分类上,这种方式的分类会出现过拟合的现象,也就是训练的分类器十分刻薄,耿直,我们需要这个分类器更灵活一些。那么后面同样会介绍避免这样的过拟合的减枝内容。