本章内容(使用Python3.6实现)
- 决策树简介
- 在数据集中度量一致性
- 使用递归构造决策树
- 使用matplotlib绘制树形图
关于决策树,我们首先讨论构造决策树的方法,以及如何编写构造树的Python代码;接着提出一些度量算法成功率的方法;最后使用递归建立分类器,并且使用Matplotlib绘制决策树图。构造完成决策树分类器之后,我们将输入一些隐形眼睛的处方数据,并由决策树分类器预测需要的镜片类型。
3.1 决策树的构造
决策树
优点:计算复杂度不高,输出结果易于理解,对中间值的缺失不敏感,可以处理不相关特征数据。
缺点:可能会产生过度匹配问题。
适用数据类型:数值型和标称型
在构造决策树时,我们需要解决的第一个问题就是,当前数据集上哪个特征在划分数据分类时起决定性作用。为了找到决定性的特征,划分出最好的结果,我们必须评估每个特征。
创建分支的伪代码函数createBranch() 如下所示:
检测数据集中的每个子项是否属于同一分类:
if so return 类标签;
else:
寻找划分数据集的最好特征
划分数据集
创建分支节点
for 每个划分的子集
调用函数creatBranch并增加返回结果到分支节点中
return 分支节点
决策树的一般流程
1)收集数据:可以使用任何方法
2)准备数据:树构造算法只适用于标称型数据,因此数值型数据必须离散化。
3)分析数据:可以使用任何方法,构造树完成之后,我们应该检查图形是否符合预期。
4)训练算法:构造树的数据结构。
5)测试算法:使用经验树计算错误率。
6)使用算法:此步骤可以适用于任何监督学习算法,而使用决策树可以更好地理解数据的内在含义。
本书采用ID3算法划分数据集,该算法处理如何划分数据集,何时停止划分数据集。每次划分数据集时我们只选取一个特征属性,如果训练集中存在20个特征,第一次选择哪个特征作为划分的参考属性?回答这个问题我们必须采用量化的方法判断如何划分数据。以下表3-1 数据为例。
不浮出水面是否可以生存 | 是否有脚蹼 | 属于鱼类 | |
1 | 是 | 是 | 是 |
2 | 是 | 是 | 是 |
3 | 是 | 否 | 否 |
4 | 否 | 是 | 否 |
5 | 否 | 是 | 否 |
3.1.1 信息增益
划分数据集的大原则是:将无序的数据变得更加有序。组织杂乱无章数据的一种方法就是使用信息论度量信息。在划分数据集之前之后信息发生的变化称为信息增益,知道如何计算信息增益,我们就可以计算每个特征值划分数据集获得的信息增益,获得信息增益最高的特征就是最好的选择。
, 其中是选择该分类的概率。
,其中n是分类的数目。
# trees.py
# 使用Python计算信息熵
from math import log
def calcShannonEnt(dataset):
numEntries = len(dataset)
labelCounts = {}
for featVec in dataset:
currentLabel = featVec[-1]
if currentLabel not in labelCounts.keys():
labelCounts[currentLabel] = 0
labelCounts[currentLabel] += 1
shannonEnt = 0.0
for key in labelCounts:
prob = float(labelCounts[key])/numEntries
shannonEnt -= prob * log(prob,2)
return shannonEnt
# 创建一个数据字典
def createDataSet():
dataSet = [[1,1,'yes'],
[1, 1, 'yes'],
[1, 0, 'no'],
[0, 1, 'no'],
[0, 1, 'no'],]
labels = ['no surfacing','flippers']
return dataSet,labels
熵越高,则混合的数据也越多。得到熵之后,我们就可以按照获取最大信息增益的方法划分数据集。
3.1.2 划分数据集
我们学习了如何度量数据集的无序程度,分类算法除了需要测量信息熵,还需要划分数据集,度量划分数据集的熵,以便判断当前是否正确地划分了数据集。我们将对每个特征划分数据集的结果计算一次信息熵,然后判断按照哪个特征划分数据集是最好的划分方式。
# 划分数据集,其中axis为划分数据集的特征,value为需要返回的特征的值
def splitDataSet(dataSet,axis,value):
retDataSet = []
for featVec in dataSet:
if featVec[axis] == value:
reducedFeatVec = featVec[:axis]
reducedFeatVec.extend(featVec[axis+1:])
retDataSet.append(reducedFeatVec)
return retDataSet
接下来我们将遍历整个数据集,循环计算香农熵和splitDataSet()函数,找到最好的特征划分方式。
# 选择最好的数据集划分方式
def chooseBestFeatureToSplit(dataSet):
numFeatures = len(dataSet[0])-1 #数据集包含特征的总数
baseEntropy = calcShannonEnt(dataSet) #原始数据集的信息熵
bestInfoGain = 0.0; bestFeature = -1 # 最佳信息增益,最佳划分特征
for i in range(numFeatures): # iterate over all the features
featList = [example[i] for example in dataSet]
# create a list of all the examples of this feature
uniqueVals = set(featList) # 集合内元素不重合,这表示某特征的特征值所组成的集合
newEntropy = 0.0
for value in uniqueVals:
subDataSet = splitDataSet(dataSet, i, value)
prob = len(subDataSet) / float(len(dataSet))
newEntropy += prob * calcShannonEnt(subDataSet)
infoGain = baseEntropy - newEntropy # calculate the info gain; ie reduction in entropy
if (infoGain > bestInfoGain): # compare this to the best gain so far
bestInfoGain = infoGain # if better than current best, set to best
bestFeature = i
return bestFeature # returns an integer
信息增益是熵的减少或者是数据无序度的减少,大家肯定对于将熵用于度量数据无序度的减少更容易理解。
3.1.3 递归构建决策树
第一次划分之后,数据将被向下传递到树分支的下一个节点,在这个节点上,我们可以再次划分数据。因此我们可以采用递归的原则处理数据集。递归结束的条件是:程序遍历完所有划分数据集的属性,或者每个分支下的所有实例都具有相同的分类。然而,如果数据集已经处理了所有属性,但是类标签依然不是唯一的,此时,我们需要决定如何定义该叶子节点,在这种情况下,我们通常会采用多数表决的方法决定该叶子节点的分类。
# 引入operator,多数表决决定该叶子节点的分类
def majorityCnt(classList):
classCount = {}
for vote in classList:
if vote not in classCount.keys():classCount[vote] = 0
classCount[vote] += 1
sortedClassCount = sorted(classCount.iteritems(),
key=operator.itemgetter(1),reverse=True)
# sortedClassCount = sorted(classCount,key=operator.itemgetter(1),reverse=True)
return sortedClassCount[0][0]
# 创建树的函数代码,采用递归的原则处理数据集
def createTree(dataSet,labels):
classList = [example[-1] for example in dataSet]
if classList.count(classList[0]) == len(classList):
return classList[0] # 类别完全相同则停止继续划分
if len(dataSet[0]) == 1:
return majorityCnt(classList) #遍历完所有特征时返回出现次数最多的类别
bestFeat = chooseBestFeatureToSplit(dataSet)
bestFeatLable = labels[bestFeat]
myTree = {bestFeatLable:{}}
del (labels[bestFeat])
featValues = [example[bestFeat] for example in dataSet]
uniqueVals = set(featValues)
for value in uniqueVals:
subLabels = labels[:]
myTree[bestFeatLable][value] = createTree(splitDataSet(dataSet,bestFeat,value),subLabels)
return myTree
3.2 在Python中使用Matplotlib注解绘制树形图
3.2.1 Matplotlib注解
# treePlotter.py
# 使用文本注解绘制树节点
import matplotlib.pyplot as plt
#定义文本框和箭头格式
decisionNode = dict(boxstyle="sawtooth",fc="0.8")
leafNode = dict(boxstyle="round4",fc="0.8")
arrow_args = dict(arrowstyle="<-")
# 绘制带箭头的注解
def plotNode(nodeTxt,centerPt,parentPt,nodeType):
createPlot.ax1.annotate(nodeTxt,xy=parentPt,xycoords='axes fraction',
xytext=centerPt,textcoords='axes fraction',
va='center',ha='center',bbox=nodeType,
arrowprops=arrow_args)
#创建节点
def createPlot():
fig = plt.figure(1,facecolor='white')
fig.clf()
createPlot.ax1 = plt.subplot(111,frameon=False)
plotNode('决策节点',(0.5,0.1),(0.1,0.5),decisionNode)
plotNode('叶节点',(0.8,0.1),(0.3,0.8),leafNode)
plt.show()
运行结果如下图所示:
3.2.2 构造注解树
构造一个完整的树需要掌握一些技巧。虽然我们有x,y坐标,但是如何放置所有的树节点却是个问题。我们必须知道有多少个叶节点,以便正确确定x轴的长度;知道树有多少层,以便可以正确确定y轴的高度。这里我们定义两个新函数getNumLeafs()和getTreeDepth(),来获取叶结点的数目和树的层数。
完整代码如下:
# 使用文本注解绘制树节点
import matplotlib.pyplot as plt
#定义文本框和箭头格式
decisionNode = dict(boxstyle="sawtooth",fc="0.8")
leafNode = dict(boxstyle="round4",fc="0.8")
arrow_args = dict(arrowstyle="<-")
# 获取叶结点的数目和树的层数
def getNumLeafs(myTree):
numLeafs = 0
firstStr = list(myTree.keys())[0]
secondDict = myTree[firstStr]
for key in secondDict.keys():
if type(secondDict[key]).__name__=='dict':
#test to see if the nodes are dictonaires, if not they are leaf nodes
numLeafs += getNumLeafs(secondDict[key])
else: numLeafs +=1
return numLeafs
def getTreeDepth(myTree):
maxDepth = 0
firstStr = list(myTree.keys())[0]
secondDict = myTree[firstStr]
for key in secondDict.keys():
if type(secondDict[key]).__name__=='dict':
thisDepth = 1 + getTreeDepth(secondDict[key])
else: thisDepth = 1
if thisDepth > maxDepth: maxDepth = thisDepth
return maxDepth
# 绘制带箭头的注解
def plotNode(nodeTxt,centerPt,parentPt,nodeType):
createPlot.ax1.annotate(nodeTxt,xy=parentPt,xycoords='axes fraction',
xytext=centerPt,textcoords='axes fraction',
va='center',ha='center',bbox=nodeType,
arrowprops=arrow_args)
# 在父子节点间填充文本信息
def plotMidText(cntrPt,parentPt,txtString):
xMid = (parentPt[0]-cntrPt[0])/2.0 + cntrPt[0]
yMid = (parentPt[1]-cntrPt[1])/2.0 + cntrPt[1]
createPlot.ax1.text(xMid,yMid,txtString)
# createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center",
# rotation=30)
def plotTree(myTree,parentPt,nodeTxt):
numLeafs = getNumLeafs(myTree)
depth = getTreeDepth(myTree)
firstStr = list(myTree.keys())[0] #计算宽与高
cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW,plotTree.yOff)
plotMidText(cntrPt,parentPt,nodeTxt)
plotNode(firstStr,cntrPt,parentPt,decisionNode)
secondDict = myTree[firstStr]
plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD
for key in secondDict.keys():
if type(secondDict[key]).__name__ == 'dict':
plotTree(secondDict[key],cntrPt,str(key))
else:
plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW
plotNode(secondDict[key],(plotTree.xOff,plotTree.yOff),cntrPt,leafNode)
plotMidText((plotTree.xOff,plotTree.yOff),cntrPt,str(key))
plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD
def createPlot(inTree):
fig = plt.figure(1, facecolor='white')
fig.clf()
axprops = dict(xticks=[], yticks=[])
createPlot.ax1 = plt.subplot(111, frameon=False, **axprops) # no ticks
# createPlot.ax1 = plt.subplot(111, frameon=False) #ticks for demo puropses
plotTree.totalW = float(getNumLeafs(inTree))
plotTree.totalD = float(getTreeDepth(inTree))
plotTree.xOff = -0.5 / plotTree.totalW;
plotTree.yOff = 1.0;
plotTree(inTree, (0.5, 1.0), '')
plt.show()
# #创建节点
# def createPlot():
# fig = plt.figure(1,facecolor='white')
# fig.clf()
# createPlot.ax1 = plt.subplot(111,frameon=False)
# plotNode('决策节点',(0.5,0.1),(0.1,0.5),decisionNode)
# plotNode('叶节点',(0.8,0.1),(0.3,0.8),leafNode)
# plt.show()
# 输出预先存储的树信息,避免了每次测试代码时都要从数据中创建树的麻烦
def retrieveTree(i):
listOfTrees =[{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},
{'no surfacing': {0: 'no', 1:
{'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}}
]
return listOfTrees[i]
if __name__ == '__main__':
# print(createPlot())
myTree = retrieveTree(0)
print(getNumLeafs(myTree))
print(getTreeDepth(myTree))
print(createPlot(myTree))
myTree['no surfacing'][3] = 'maybe'
print(myTree)
print(createPlot(myTree))
如图所示:
3.3 测试和存储分类器
3.3.1 测试算法:使用决策树执行分类
# 使用决策树的分类函数
def classify(inputTree,featLabels,testVec):
firstStr = list(inputTree.keys())[0]
secondDict = inputTree[firstStr]
featIndex = featLabels.index(firstStr)
for key in list(secondDict.keys()):
if testVec[featIndex] == key:
if type(secondDict[key]).__name__=='dict':
classLabel = classify(secondDict[key],featLabels,testVec)
else:
classLabel = secondDict[key]
return classLabel
3.3.2 存储决策树
# 使用pickle模块存储决策树
def storeTree(inputTree, filename):
import pickle
fw = open(filename, 'wb') # 用pickle序列化后的是二进制,所以此处用wb,
# 以二进制写入,不然默认以字符串写入会出错
fw.write(pickle.dumps(inputTree))
fw.close()
def grabTree(filename):
import pickle
fr = open(filename, 'rb') # 用'rb'
data = pickle.loads(fr.read())
fr.close()
return data
# 也可以使用json模块
def storeTree(inputTree, filename):
import json
fw = open(filename, 'w') # 只需要'w'
fw.write(json.dumps(inputTree))
fw.close()
def grabTree(filename):
import json
fr = open(filename, 'r') # 只需要'r'
data = json.loads(fr.read())
fr.close()
return data
示例一:使用决策树预测隐形眼镜类型
'''
@Project -> File :ML_in_action -> testLenses
@IDE :PyCharm
@Author :NatW
@Date :2019/11/1 21:16'''
import trees
import treePlotter
fr = open('lenses.txt')
lenses = [inst.strip().split('\t') for inst in fr.readlines()]
lensesLabels=['age','prescript','astigmatic','tearRate']
lensesTree = trees.createTree(lenses,lensesLabels)
print(lensesTree)
print(treePlotter.createPlot(lensesTree))
结果图如下:
倘若匹配选项过多,称之为“过度匹配”。为了减少过度匹配问题,我们可以裁剪决策树,去掉一些不必要的叶子节点。如果叶子节点只能增加少许信息,则可以删除该节点,将它并入到其他叶子节点中。
小结:本章使用的算法成为ID3,它是一个很好的算法但并不完美。ID3算法无法直接处理数值型数据,尽管我们可以通过量化的方法将数值型数据转化为标称型数值,但是如果存在太多的特征划分,ID3算法仍然会面临其他问题。还有其他的决策树的构造算法,最流行的有C4.5和CART,在讨论回归问题时会介绍CART算法。
决策树分类器就像带有终止块的流程图,终止块表示分类结果。开始处理数据集时,我们首先需要测量集合中数据的不一致性,也就是熵,然后寻找最优方案划分数据集,直到数据集中的所有数据属于同一分类。ID3算法可以用于划分标称型数据集。