1. The Decision Tree Model
1.1 Definition
A classification decision tree model is a tree structure that describes how instances are classified. A decision tree consists of nodes and directed edges. Nodes come in two types: internal nodes and leaf nodes. An internal node represents a feature or attribute; a leaf node represents a class.
The figure below shows a decision tree model, where circles denote internal nodes and boxes denote leaf nodes.
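As a minimal sketch (using the same toy tree that the code in Section 3 builds), the nested-dict representation used throughout this article encodes an internal node as a one-key dict mapping the splitting feature to its children, and a leaf as a bare class label:

# Internal node: {feature: {feature value: subtree}}; leaf: a class label.
# This is exactly the structure createTree returns below.
tree = {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
# Walking the tree for an instance with no surfacing = 1, flippers = 1:
print(tree['no surfacing'][1]['flippers'][1])  # yes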
1.2 Decision Tree Learning
Decision tree learning typically proceeds in three steps, which the rest of this article follows in turn: feature selection, tree generation, and tree pruning.
2. Feature Selection
Feature selection means choosing the features that have real discriminative power for the training data, which makes decision tree learning more efficient. The usual selection criteria are information gain and information gain ratio.
2.1 Definition of Entropy
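Entropy measures the uncertainty of a random variable. For a discrete random variable X with P(X = x_i) = p_i, i = 1, ..., n, the entropy is defined as

H(X) = -\sum_{i=1}^{n} p_i \log_2 p_i

The empirical entropy H(D) of a data set D uses class frequencies in place of the p_i; this is exactly what calEntropy below computes.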
2.2 Conditional Entropy
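Conditional entropy H(Y|X) measures the uncertainty that remains about Y once X is known:

H(Y|X) = \sum_{i=1}^{n} p_i H(Y|X = x_i), \quad p_i = P(X = x_i)

When the probabilities are estimated from data, the result is called the empirical conditional entropy.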
2.3 Information Gain
Information gain measures how much the uncertainty about the class Y is reduced once the value of a feature X is known.
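For a training set D and a feature A, the information gain is the empirical entropy minus the empirical conditional entropy:

g(D, A) = H(D) - H(D|A)

The ID3 algorithm in Section 3 picks, at every node, the feature with the largest g(D, A).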
2.4 Information Gain Ratio
The magnitude of an information gain value is relative to the training set: when the empirical entropy of the training set is large, information gain values tend to be large, and vice versa. In other words, choosing split features by information gain is biased toward features that take many distinct values. The information gain ratio corrects this bias.
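Concretely, the information gain ratio normalizes the gain by the entropy of D with respect to the values of feature A:

g_R(D, A) = \frac{g(D, A)}{H_A(D)}, \quad H_A(D) = -\sum_{i=1}^{n} \frac{|D_i|}{|D|} \log_2 \frac{|D_i|}{|D|}

where D_i is the subset of D on which feature A takes its i-th value. A feature with many distinct values has a large H_A(D), so its gain is penalized accordingly.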
3. The ID3 Algorithm (Iterative Dichotomiser 3)
ID3 builds the tree recursively: at each node it chooses the feature with the largest information gain, splits the data on that feature's values, and recurses until every sample at a node belongs to the same class or no features remain.
3.1 Code Implementation
# -*- coding: utf-8 -*-
"""
Created on Fri Apr 13 18:50:19 2018
file name: tree.py
@author: lizihua
"""
from math import log
import operator

# Create a toy data set
def createDataSet():
    dataSet = [[1,1,'yes'],[1,1,'yes'],[1,0,'no'],[0,1,'no'],[0,1,'no']]
    labels = ['no surfacing','flippers']
    return dataSet, labels
# Compute the entropy of a given data set
def calEntropy(dataSet):
    numentries = len(dataSet)
    labelCounts = {}
    # Count with a dict how often each class occurs in the data set
    for featVec in dataSet:
        currentLabel = featVec[-1]
        if currentLabel not in labelCounts.keys():
            labelCounts[currentLabel] = 0
        labelCounts[currentLabel] += 1
    # Compute the entropy
    entropy = 0.0
    for key in labelCounts:
        # Probability of this class
        prob = float(labelCounts[key])/numentries
        entropy -= prob*log(prob,2)
    return entropy
# Split the data set on a given feature
# dataSet: the data set to split; axis: index of the splitting feature;
# value: the feature value the returned rows must have
# Assume dataSet has n rows and m features
def splitDataSet(dataSet, axis, value):
    retDataSet = []
    # featVec is one row of m feature values plus the class label
    for featVec in dataSet:
        if featVec[axis] == value:
            reduceFeatVec = featVec[:axis]
            reduceFeatVec.extend(featVec[axis+1:])
            # reduceFeatVec is featVec with the feature featVec[axis] removed
            retDataSet.append(reduceFeatVec)
    return retDataSet
# Choose the best feature to split the data set on
# Information gain criterion: compute the information gain of every feature of
# the training set dataSet, compare them, and pick the feature with the largest gain
# Information gain: g(dataSet, Feature) = H(dataSet) - H(dataSet | Feature)
def chooseBestFeatureToSplit(dataSet):
    # Number of features (the last column is the class label)
    numFeatures = len(dataSet[0])-1
    # baseEntropy is H(dataSet)
    baseEntropy = calEntropy(dataSet)
    # bestInfoGain is the largest information gain g seen so far
    bestInfoGain = 0.0
    bestFeature = -1
    for i in range(numFeatures):
        featList = [example[i] for example in dataSet]
        # Distinct values of this feature; e.g. if the feature is age,
        # uniqueVals might be {young, middle-aged, old}
        uniqueVals = set(featList)
        # Empirical conditional entropy newEntropy, i.e. H(dataSet | Feature)
        newEntropy = 0.0
        for value in uniqueVals:
            subDataSet = splitDataSet(dataSet,i,value)
            prob = len(subDataSet)/float(len(dataSet))
            newEntropy += prob*calEntropy(subDataSet)
        # Information gain of this feature
        infoGain = baseEntropy - newEntropy
        # Keep the feature with the largest information gain
        if (infoGain > bestInfoGain):
            bestInfoGain = infoGain
            bestFeature = i
    return bestFeature
# Decide the class of a leaf node by majority vote
# Similar to the voting code in kNN
def majorityCnt(classList):
    # Build a dict (key: class, value: how often that class occurs),
    # sort it by value in descending order, and return the key (class)
    # with the largest value
    classCount = {}
    for vote in classList:
        if vote not in classCount.keys():
            classCount[vote] = 0
        classCount[vote] += 1
    sortedClassCount = sorted(classCount.items(),key=operator.itemgetter(1),reverse=True)
    return sortedClassCount[0][0]
# Build the decision tree recursively
# Arguments: the data set and a label list holding the names of all its features
def createTree(dataSet,labels):
    classList = [example[-1] for example in dataSet]
    # The recursion has two stopping conditions:
    # 1. All class labels are identical: return that label
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    # 2. Every feature has been used, yet the data set still holds more than one class.
    #    When all features are consumed, len(dataSet[0]) == 1, i.e. only the
    #    class-label column remains; return the most frequent class
    if len(dataSet[0]) == 1:
        return majorityCnt(classList)
    # Choose the best splitting feature bestFeat (a column index)
    bestFeat = chooseBestFeatureToSplit(dataSet)
    bestFeatLabel = labels[bestFeat]
    # Store the whole tree in the dict myTree
    myTree = {bestFeatLabel:{}}
    del(labels[bestFeat])
    # Values of the bestFeat column
    featValues = [example[bestFeat] for example in dataSet]
    # Set of distinct values of the bestFeat column
    uniqueVals = set(featValues)
    # Recursively build a subtree for each value
    for value in uniqueVals:
        subLabels = labels[:]
        myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet,
                                                  bestFeat,value),subLabels)
    return myTree
# Classify a test vector with the decision tree
def classify(inputTree,featLabels,testVec):
    firstStr = list(inputTree.keys())[0]
    secondDict = inputTree[firstStr]
    # Convert the feature label string to a column index
    featIndex = featLabels.index(firstStr)
    key = testVec[featIndex]
    valueOfFeat = secondDict[key]
    if isinstance(valueOfFeat, dict):
        classLabel = classify(valueOfFeat, featLabels, testVec)
    else:
        classLabel = valueOfFeat
    return classLabel
# Store the decision tree with the pickle module
def storeTree(inputTree,filename):
    import pickle
    fw = open(filename,'wb')
    pickle.dump(inputTree,fw)
    fw.close()

# Read back the file written above with the pickle module
def grabTree(filename):
    import pickle
    fr = open(filename,'rb')
    tree = pickle.load(fr)
    fr.close()
    return tree
Test code:
if __name__ == "__main__":
    myData,labels = createDataSet()
    print(myData)             # [[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]
    print(labels)             # ['no surfacing', 'flippers']
    print(calEntropy(myData)) # 0.9709505944546686
    """
    # The more classes there are, the larger the entropy
    myData[0][-1] = 'maybe'
    print(myData)             # [[1, 1, 'maybe'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]
    print(calEntropy(myData)) # 1.3709505944546687
    print(splitDataSet(myData,0,1))         # [[1, 'yes'], [1, 'yes'], [0, 'no']]
    print(chooseBestFeatureToSplit(myData)) # 0, i.e. feature 0 is the best feature to split the data set on
    myTree = createTree(myData,labels)
    print(myTree) # {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
    """
    myTree = {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
    print(classify(myTree,labels,[1,0])) # no
    print(classify(myTree,labels,[1,1])) # yes
    storeTree(myTree,'classifierStorage.txt')
    print(grabTree('classifierStorage.txt'))
    # result: {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
    # Read the contact-lens data
    fr = open('lenses.txt')
    lenses = [inst.strip().split('\t') for inst in fr.readlines()]
    lensesLabels = ['ages','prescript','astigmatic','tearRate']
    lensesTree = createTree(lenses,lensesLabels)
    # Print the contact-lens decision tree as a dict
    print(lensesTree)
    """lensesTree:
    result: {'tearRate': {'normal': {'astigmatic': {'yes': {'prescript':
    {'myope': 'hard', 'hyper': {'ages': {'pre': 'no lenses', 'young':
    'hard', 'presbyopic': 'no lenses'}}}}, 'no': {'ages': {'pre':
    'soft', 'young': 'soft', 'presbyopic': {'prescript': {'myope':
    'no lenses', 'hyper': 'soft'}}}}}}, 'reduced': 'no lenses'}}
    """
3.2 Drawing the Tree with matplotlib Annotations
Code implementation:
# -*- coding: utf-8 -*-
"""
Created on Sun Apr 15 18:41:40 2018
file name: treePlot.py
@author: lizihua
"""
import matplotlib.pyplot as plt
from tree import createTree

# Draw the tree with matplotlib's annotation facility
# Annotate tree nodes with text
# Define the box and arrow styles
decisionNode = dict(boxstyle="sawtooth",fc="0.8")
leafNode = dict(boxstyle="round4",fc="0.8")
arrow_args = dict(arrowstyle="<-")

# Draw an annotation with an arrow
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',
                            xytext=centerPt, textcoords='axes fraction',
                            va="center", ha="center", bbox=nodeType, arrowprops=arrow_args)
# Build the annotated tree
# Count the leaf nodes
def getNumLeafs(myTree):
    numLeafs = 0
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        # If the child is a dict, it is an internal node: recurse into it
        if type(secondDict[key]).__name__ == 'dict':
            numLeafs += getNumLeafs(secondDict[key])
        else:
            numLeafs += 1
    return numLeafs
# Get the number of levels of the tree
def getTreeDepth(myTree):
    maxDepth = 0
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        # If the child is a dict, it is an internal node: recurse into it
        if type(secondDict[key]).__name__ == 'dict':
            thisDepth = 1 + getTreeDepth(secondDict[key])
        else:
            thisDepth = 1
        if thisDepth > maxDepth:
            maxDepth = thisDepth
    return maxDepth
# Fill in text between parent and child nodes
def plotMidText(cntrPt, parentPt, txtString):
    xMid = (parentPt[0]-cntrPt[0])/2.0 + cntrPt[0]
    yMid = (parentPt[1]-cntrPt[1])/2.0 + cntrPt[1]
    createPlot.ax1.text(xMid, yMid, txtString)
def plotTree(myTree, parentPt, nodeTxt):
    # Compute the width and height
    numLeafs = getNumLeafs(myTree)
    depth = getTreeDepth(myTree)
    firstStr = list(myTree.keys())[0]
    cntrPt = (plotTree.xOff + (1.0+float(numLeafs))/2.0/plotTree.totalW, plotTree.yOff)
    # Label the edge with the child node's feature value
    plotMidText(cntrPt, parentPt, nodeTxt)
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
    # Decrease the y offset
    plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            plotTree(secondDict[key], cntrPt, str(key))
        else:
            plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW
            plotNode(secondDict[key], (plotTree.xOff,plotTree.yOff), cntrPt, leafNode)
            plotMidText((plotTree.xOff,plotTree.yOff), cntrPt, str(key))
    plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD
def createPlot(inTree):
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    axprops = dict(xticks=[], yticks=[])
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
    plotTree.totalW = float(getNumLeafs(inTree))
    plotTree.totalD = float(getTreeDepth(inTree))
    plotTree.xOff = -0.5/plotTree.totalW
    plotTree.yOff = 1.0
    plotTree(inTree, (0.5,1.0), '')
    plt.show()
# Return a pre-stored tree
def retrieveTree(i):
    listOfTree = [{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},
                  {'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}}]
    return listOfTree[i]
Test code 1:
if __name__ == "__main__":
    print(retrieveTree(1))
    # result: {'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}}
    myTree = retrieveTree(0)
    print(getNumLeafs(myTree))  # 3
    print(getTreeDepth(myTree)) # 2
    createPlot(myTree)
Test result 1:
Test code 2 (contact-lens data):
if __name__ == "__main__":
    fr = open('lenses.txt')
    lenses = [inst.strip().split('\t') for inst in fr.readlines()]
    lensesLabels = ['ages','prescript','astigmatic','tearRate']
    lensesTree = createTree(lenses,lensesLabels)
    print(lensesTree)
    createPlot(lensesTree)
Test result 2:
4. Pruning the Decision Tree
A simple method for pruning a decision tree:
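One common simple scheme minimizes a cost that trades training fit against tree size. With |T| leaves, N_t samples at leaf t, and empirical entropy H_t(T) at leaf t, the overall loss is

C_\alpha(T) = \sum_{t=1}^{|T|} N_t H_t(T) + \alpha |T|, \quad \alpha \ge 0

Pruning then works bottom-up: a group of sibling leaves is collapsed into its parent node whenever the collapse does not increase C_\alpha(T); larger values of \alpha favor smaller trees.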