决策树
算法简介
组成:长方形为判断模块,椭圆形为终止模块,
优点:计算复杂度不高,输出结果易于理解,对中间值的缺失不敏感,可以处理不相关特征值数据。
缺点:可能会产生过的匹配问题
适用数据类型:数据型和标称型。
算法实现
决策树的一般流程
- 收集数据:可使用任何方法。
- 准备数据:树构造算法只适用于标称型数据,因此数值型数据必须离散化。
- 分析数据:可以使用任何方法,构造树完成之后,我们应该检查图形是否符合预测。
- 训练算法:构造树的数据结构。
- 测试算法:使用经验树计算错误率。
- 使用算法:此步骤可以适用于任何监督学习算法,而使用决策树可以更好地理解数据的内在含义。
利用递归算法创建决策树
创建分支伪代码函数createBranch():
检测数据集中的每个子项是否属于同于分类:
If so return 类标签
Else
寻找划分数据集的最好特征
划分数据集
创建分支节点
For每个划分的子集
调用函数createBranch并增加返回结果到分支节点中
Return 分支节点
数据划分
问题:如何采用量化的方法判断如何划分数据
原则:将无序的数据变得更加有序
判断方法:按照获取最大信息增益的方法划分数据(还可以采用基尼不纯度划分)
信息增益:划分数据集之前之后信息发生的变化
熵:集合信息的度量方式,计算方法如下所示:
符号 的信息定义:
为了计算熵,我们需要计算所有类别所有可能值包含的信息期望值:
数据划分算法
一些决策树采用二分法进行数据划分,此处采用id3算法。
ID3算法是决策树的一种,它是基于奥卡姆剃刀原理的,即用尽量用较少的东西做更多的事。ID3算法,即Iterative Dichotomiser 3,迭代二叉树3代,是Ross Quinlan发明的一种决策树算法,这个算法的基础就是上面提到的奥卡姆剃刀原理,越是小型的决策树越优于大的决策树,尽管如此,也不总是生成最小的树型结构,而是一个启发式算法。在信息论中,期望信息越小,那么信息增益就越大,从而纯度就越高。ID3算法的核心思想就是以信息
增益来度量属性的选择,选择分裂后信息增益最大的属性进行分裂。该算法采用自顶向下的贪婪搜索遍历可能的决策空间。
进一步了解id3算法
代码实现
代码分为构造决策树、绘制决策树两个部分
构造决策树
创建名为trees.py的文件
from math import log
import operator
"""
函数calcShannonEnt(dataSet):
参数:dataSet:待处理的数据集
功能:计算给定数据集的香农熵
"""
def calcShannonEnt(dataSet):
numEntries = len(dataSet)#计算数据集中实例总数
labelCounts = {}
#遍历数据集,若当前键值不存在,则扩展字典并将当前键值加入字典,并统计各类出现次数
for featVec in dataSet:
currentLabel = featVec[-1]
if currentLabel not in labelCounts.keys():
labelCounts[currentLabel] = 0
labelCounts[currentLabel] += 1
shannonEnt = 0.0
for key in labelCounts:
prob = float(labelCounts[key])/numEntries #计算p(xi)
shannonEnt -= prob * log(prob,2)#计算香农熵
return shannonEnt
"""
函数splitDataSet(dataSet, axis, value):
参数:dataSet:带划分的数据集
axis:划分数据集的特征
value:需要返回的特征的值
功能:按照给定的特征值划分数据集
"""
def splitDataSet(dataSet, axis, value):
retDataSet = []#创建新的list对象,避免函数对原数据集的修改
#将符合特征的数据抽取出来
for featVec in dataSet:
if featVec[axis] == value:
reduceFeatVec = featVec[:axis]
reduceFeatVec.extend(featVec[axis+1:])
retDataSet.append(reduceFeatVec)
return retDataSet
"""
函数chooseBestFeatureToSplit(dataSet):
参数:dataSet:待处理的数据集
功能:选择最好的数据集划分方式
在函数中调用的数据有一定的要求:
1.数据必须是一种由列表元素组成的列表,且所有的列表元素都要具有相同的数据长度;
2.数据的最后一列或每个实例最后一个元素是当前实例的类别标签。
"""
def chooseBestFeatureToSplit(dataSet):
numFeatures = len(dataSet[0]) - 1#计算数据特征总数,例:【属性A,属性B...属性N,类别】共N+1项,有N个特征
baseEntropy = calcShannonEnt(dataSet)#数据集原始香农熵
bestInfoGain = 0.0
BestFeature = -1
#遍历数据集中的所有特征
for i in range(numFeatures):
featList = [example[i] for example in dataSet]#将数据集中所有第i个特征值的可能值写入featList
uniqueVals = set(featList)#选取不重复的特征值集合,set是python中得到列表中唯一元素值的最快方法
newEntropy = 0.0
#遍历不重复的特征值集合,对每个特征值划分一次数据集计算信息增益
for value in uniqueVals:
subDataSet = splitDataSet(dataSet, i, value)
prob = len(subDataSet)/float(len(dataSet))
newEntropy += prob * calcShannonEnt(subDataSet)
infoGain = baseEntropy - newEntropy
#比较所有的特征中的信息增益,返回最好划分的索引值
if (infoGain > bestInfoGain):
bestInfoGain = infoGain
BestFeature = i
return BestFeature
"""
函数majorityCnt(classList):
参数:classList:类标签列表
功能:选出类标签中最多的一类
"""
def majorityCnt(classList):
classCount = {}
#记录各类出现的次数
for vote in classList:
if vote not in classCount.keys(): classCount[vote] = 0
classCount[vote] += 1
#对各类出现次数排序
sortedClassCount = sorted(classCount.items(),
key = operator.itemgetter(1),
reverse=True)
return sortedClassCount[0][0]
"""
函数createTree(dataSet, labels):
参数:dataSet:数据集
labels:标签列表
功能:建立决策树
"""
def createTree(dataSet, labels):
classList = [example[-1] for example in dataSet]
#递归条件1:如果类别完全相同则停止继续划分
if classList.count(classList[0]) == len(classList):
return classList[0]
#递归条件2:遍历完所有特征值时返回出现次数最多的类别
if len(dataSet[0]) == 1:
return majorityCnt(classList)
#开始创建树
bestFeat = chooseBestFeatureToSplit(dataSet)
bestFeatLabel = labels[bestFeat]
myTree = {bestFeatLabel:{}}#此处采用嵌套词典的存储树
#得到列表包含的所有属性值
del(labels[bestFeat])
featValues = [example[bestFeat] for example in dataSet]
uniqueVals = set(featValues)
for value in uniqueVals:
subLabels = labels[:]#为了不改变原始列表内容,复制类标签,将其存储在新列表变量subLabels中
myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), subLabels)
return myTree
"""
函数classify(inputTree, featLabels, testVec):
参数:inputTree:数据集
featLabels:特征标签列表
testVec:测试向量
功能:使用决策树分类
"""
def classify(inputTree, featLabels, testVec):
firstStr = list(inputTree.keys())[0]
secondDict = inputTree[firstStr]
featIndex = featLabels.index(firstStr)
for key in secondDict.keys():
if testVec[featIndex] == key:
if type(secondDict[key]).__name__ == 'dict':
classLabel = classify(secondDict[key], featLabels, testVec)
else: classLabel = secondDict[key]
return classLabel
"""
用于存储和读取决策树
"""
def storeTree(inputTree, filename):
import pickle
#fw = open(filename, 'w')
fw = open(filename, mode = 'wb')
pickle.dump(inputTree, fw)
fw.close()
def grabTree(filename):
import pickle
fr = open(filename, 'rb')
return pickle.load(fr)
绘制决策树
创建名为treePlotter.py的文件
import matplotlib.pyplot as plt
#定义文本框和箭头格式
decisionNode = dict(boxstyle="sawtooth", fc="0.8")
leafNode = dict(boxstyle="round4", fc="0.8")
arrow_args = dict(arrowstyle = "<-")
#绘制带箭头的注解
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
createPlot.axl.annotate(nodeTxt, xy = parentPt, xycoords = 'axes fraction',
xytext = centerPt, textcoords = 'axes fraction',
va = "center", ha = "center", bbox = nodeType,
arrowprops = arrow_args)
"""
#此为简单的示例函数,后面会拓展为完整的绘制函数
def createPlot():
fig = plt.figure(1, facecolor='white')
fig.clf()
createPlot.axl = plt.subplot(111, frameon=False)
plotNode("a decision node", (0.5, 0.1), (0.1, 0.5), decisionNode)
plotNode("a leaf node", (0.8, 0.1), (0.3, 0.8), leafNode)
plt.show()
"""
"""
函数getNumLeafs,getTreeDepth
参数:myTree:已生成的决策树
功能:获取叶节点的数目和树的层数
"""
def getNumLeafs(myTree):
numLeafs = 0
firstStr = list(myTree.keys())[0]#必须用list将myTree.keys()类型转换,否则为dict_key类型,无法作为list使用
secondDict = myTree[firstStr]
#遍历树的所有子节点
for key in secondDict.keys():
#若子节点类型为字典,则该节点为判断节点,需递归调用getNumLeafs
if type(secondDict[key]).__name__ == 'dict':
numLeafs += getNumLeafs(secondDict[key])
else: numLeafs += 1
return numLeafs
def getTreeDepth(myTree):
maxDepth = 0
firstStr = list(myTree.keys())[0]
secondDict = myTree[firstStr]
for key in secondDict.keys():
if type(secondDict[key]).__name__ == 'dict':
thisDepth = 1 + getTreeDepth(secondDict[key])
else: thisDepth = 1
if thisDepth > maxDepth: maxDepth = thisDepth
return maxDepth
#绘制决策树
def plotMidText(cntrPt, parentPt, txtString):
#在父子节点间填充文本信息
xMid = (parentPt[0] - cntrPt[0])/ 2.0 + cntrPt[0]
yMid = (parentPt[1] - cntrPt[1])/ 2.0 + cntrPt[1]
createPlot.axl.text(xMid, yMid, txtString)
def plotTree(myTree, parentPt, nodeTxt):
#计算宽和高
numLeafs = getNumLeafs(myTree)
depth = getTreeDepth(myTree)
firstStr = list(myTree.keys())[0]
cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, plotTree.yOff)
#标记子节点属性值
plotMidText(cntrPt, parentPt, nodeTxt)
#减少y的偏移
plotNode(firstStr, cntrPt, parentPt, decisionNode)
secondDict = myTree[firstStr]
plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD
for key in secondDict.keys():
if type(secondDict[key]).__name__ == 'dict':
plotTree(secondDict[key], cntrPt, str(key))
else:
plotTree.xOff = plotTree.xOff + 1.0/ plotTree.totalW
plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD
def createPlot(inTree):
fig = plt.figure(1, facecolor='white')
fig.clf()
axprops = dict(xticks=[], yticks=[])
createPlot.axl = plt.subplot(111, frameon = False, **axprops)
plotTree.totalW = float(getNumLeafs(inTree))#存储树的宽度
plotTree.totalD = float(getTreeDepth(inTree))#存储树的深度
plotTree.xOff = -0.5/plotTree.totalW;
plotTree.yOff = 1.0;
plotTree(inTree, (0.5, 1.0), '')
plt.show()
测试代码
源码、数据集下载地址
import trees
import treePlotter
#用于测试的dataSet, labels的
def createDataSet():
dataSet = [
[1, 1, 'yes'],
[1, 1, 'yes'],
[1, 0, 'no'],
[0, 1, 'no'],
[0, 1, 'no']
]
labels = ['no surfacing', 'flippers']
return dataSet, labels
#用于测试的决策树
def retrieveTree(i):
listOfTrees = [
{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},
{'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}}
]
return listOfTrees[i]
"""
#测试函数calcShannonEnt
myDat, labels = createDataSet()
print(myDat)
x1 = trees.calcShannonEnt(myDat)
print(x1)#如果你的程序正确,x1=0.9709505944546686
#数据混合越多,熵越高,增加一个分类,查看熵的变化
myDat[0][-1]='maybe'
x1 = trees.calcShannonEnt(myDat)
print(x1)#如果你的程序正确,x1=1.3709505944546687
"""
"""
#方法append,extend的区别
a = [1, 2, 3]
b = [4, 5, 6]
a.append(b)#b作为一个元素加入列表a中,应得到[1, 2, 3, [4, 5, 6]]
print(a)
a = [1, 2, 3]#若不重写a,a将为[1, 2, 3, [4, 5, 6]]
a.extend(b)#得到一个包含a和b所有元素的列表,[1, 2, 3, 4, 5, 6]
print(a)
"""
"""
#测试函数splitDataSet
myDat, labels = createDataSet()
print(myDat)#[[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]
x1 = trees.splitDataSet(myDat, 0, 1)
x2 = trees.splitDataSet(myDat, 0, 0)
print(x1)#[[1, 'yes'], [1, 'yes'], [0, 'no']]
print(x2)#[[1, 'no'], [1, 'no']]
"""
"""
#测试函数chooseBestFeatureToSplit
myDat, labels = createDataSet()
x1 = trees.chooseBestFeatureToSplit(myDat)
print(x1)#返回值应为:0
"""
"""
#测试函数createTree
myDat, labels = createDataSet()
myTree = trees.createTree(myDat, labels)
print(myTree)
#报错:IndexError: list index out of range,代码编写有问题括号不匹配
"""
"""
#测试函数createPlot
treePlotter.createPlot()
"""
"""
#测试函数函数getNumLeafs,getTreeDepth
myTree = retrieveTree(0)
x1 = treePlotter.getNumLeafs(myTree)
x2 = treePlotter.getTreeDepth(myTree)
print(x1)#结果为3
print(x2)#结果为2
"""
"""
#测试函数createPlot
myTree = retrieveTree(0)
#treePlotter.createPlot(myTree)
myTree ['no surfacing'][3] = 'maybe'
treePlotter.createPlot(myTree)
"""
"""
#测试函数classify
myDat, labels = createDataSet()
myTree = retrieveTree(0)
print(labels)
print(myTree)
x1 = trees.classify(myTree, labels, [1,0])
x2 = trees.classify(myTree, labels, [1,1])
print(x1)#应为:no
print(x2)#应为:yes
"""
"""
#测试函数storeTre,grabTree
myTree = retrieveTree(0)
trees.storeTree(myTree, 'classifierStorage.txt')
x = trees.grabTree('classifierStorage.txt')
print(x)
"""
#示例:使用决策树预测隐形眼镜类型
fr = open('lenses.txt')
lenses = [inst.strip().split('\t') for inst in fr.readlines()]
lensesLabels = ['age', 'prescript', 'astigmatic', 'tearRate']
lensesTree = trees.createTree(lenses, lensesLabels)
print(lensesTree)
treePlotter.createPlot(lensesTree)
示例:使用决策树预测隐形眼镜类型结果为
总结
- 决策树存在过度匹配的问题,为了解决这个问题需要对决策树进行修剪,相关问题会在后续学习过程中进行学习;
- 我使用的构造算法为ID3算法,无法直接处理数据型,还有其他决策树构造算法,如:C4.5和CART,之后会去进一步了解这些算法。