之前我们已经了解了从数据集构造决策树的各种子功能模块,原理:从原始数据中基于最好的特征值进行划分数据集,由于特征值可能多余两个,所以可能存在大于两个分支的数据集划分。第一次划分之后数据将被传递到树分支的下一个节点,在这个节点上,我们可以再次划分数据。因此我们可以用递归的原则处理数据。
递归结束的条件是:程序遍历完所有划分数据集的属性,或则每个分支下的所有实例都具有相同的分类。如果所有实例具有相同的分类,则得到一个叶子节点或终止模块。任何到达叶子节点的数据必然属于叶子节点的分类。
第一个条件可以使算法终止,我们在算法开始之前计算列的数目,查看算法是否使用了所有属性即可。如果数据集已经处理了所有属性,但是类标签不唯一,此时我们需要决定如何定义该叶子节点,这种情况下通常使用多数表决的方法决定叶子节点的分类。
def majorityCnt(classList):
#存储每类标签出现的频率,按照从小到大的顺序进行排列
classCount = {}
for vote in classList:
if vote not in classList.keys():
classCount[vote] += 1
classCount[vote] += 1
sortedclassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
return sortedclassCount
创建树的函数代码:
def createTree(dataSet, labels):
'''
输入参数:数据集和标签列表
算法本身并不需要这个变量,为了给出数据明确的含义,我们将他作为一个输入参数提供给
'''
#所有数据集的类标签
classList = [example[-1] for example in dataSet]
#第一个停止条件:所有类标签完全相同,直接返回该类标签
if classList.count(classList[0]) == len(classList):
return classList[0]
#使用完了所有特征仍不能将数据集划分成一类,返回出现次数最多的类别作为返回值
if len(dataSet[0]) == 1:
return majorityCnt(classList)
#当前数据集中选取的最好的特征值 返回值为特征值的索引
bestFeat = chooseBestFeatureToSplit(dataSet)
print(bestFeat)
#最好特征的特征名称
bestFeatLabel = labels[bestFeat]
print(bestFeatLabel)
myTree = {bestFeatLabel:{}}
print(myTree)
#分类结束后删除当前特征
del(labels[bestFeat])
#遍历所有样本集(数据集)中的最好特征对应的特征值
featValues = [example[bestFeat] for example in dataSet]
uniqueVals = set(featValues)
#特征对应的特征值
for value in uniqueVals:
'''
当函数参数是列表类型时,参数是按照引用方式传递的。
为了保证每次调用函数createTree()时不改变原始列表的内容,使用新变量代替原始列表
'''
#复制了所有的特征名称(这里是我觉得和书不一样的地方)
subLabels = labels[:]
myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), subLabels)
print(myTree)
return myTree
执行上述语句,我们可以的到如下的结果:
变量myTree包含了很多代表树结构的嵌套字典。从左边开始,第一个关键字:no surfacing是第一个划分数据集的特征名字,该关键字的值也是另一个数据字典。第二个关键字是no surfacing特征划分的数据集,这些关键字的值都是no surfacing节点的子节点。这些值可以是类标签,也可以是另一个数据字典。如果值是类标签,则该节点是叶子节点,如果值是另一个数据字典,那么子节点是一个判断节点。这种不断重复的结构就构成了整棵树。
下面我们将使用matplotlib注解来绘制树形图:
1、matplotlib注解
import matplotlib.pyplot as plt
#定义树节点格式的常量
#文本框和箭头格式
decisionNode = dict(boxstyle="sawtooth", fc="0.8") #boxstyle="sawtooth"边框线是波浪线 fc注解框的颜色深度
leafNode = dict(boxstyle="round4", fc="0.8")
arrow_args = dict(arrowstyle="<-")
#执行绘图功能
#绘制带箭头的注解
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',
xytext=centerPt, textcoords='axes fraction',
va="center", ha="center", bbox=nodeType, arrowprops=arrow_args )
#xy=parentPt起点位置, #xytext=centerPt注解框位置
#创建一个新的绘图区,在上面绘制两个不同类型的树节点
def createPlot():
fig = plt.figure(1, facecolor='white')
fig.clf()
createPlot.ax1 = plt.subplot(111, frameon=False)
plotNode('a decision node', (0.5,0.1), (0.1,0.5), decisionNode)
plotNode('a leaf node', (0.8,0.1), (0.3,0.8), leafNode)
plt.show()
我们掌握了绘制绘制树节点的方法后,下面将学习如何绘制整棵树。
2、构造注解树
我们需要知道有多少个叶节点,以便正确的确定x轴的长度;我们还需要知道有多少层,以便确定y的高度。我们通过定义两个函数来确定叶节点的数目和树的层数。
我们使用如下两个函数来获取叶节点的数目和树的层数。
def getNumLeafs(myTree):
numLeafs = 0
firstStr = list(myTree.keys())[0]
#python3与python2的区别,python3要把键变成一个列表
#print(firstStr) secondDict = myTree[firstStr] #print(secondDict) for key in secondDict.keys(): #print(key) if type(secondDict[key]).__name__ == 'dict': numLeafs += getNumleafs(secondDict[key]) else: numLeafs += 1 return numLeafs
def getTreeDepth(myTree):
maxDepth = 0
firstStr = list(myTree.keys())[0]
secondDict = myTree[firstStr]
for key in secondDict.keys():
if type(secondDict[key]).__name__ == 'dict':
thisDepth = 1+getTreeDepth(secondDict[key])
else:
thisDepth = 1
if thisDepth > maxDepth:maxDepth = thisDepth
return maxDepth
我们用一个函数预先输出存储树的信息:
def retrieveTree(i):
listOfTrees = [{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},
{'no surfacing': {0: 'no', 1: {'flippers':
{0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}}]
return listOfTrees[i]
我们现在将之前所学的方法组合在一起,绘制一棵完整的树:
def plotMidText(cntrPt, parentPt, txtString):
#在父子节点间填充文本信息
xMid = (parentPt[0] - cntrPt[0])/2.0 + cntrPt[0]
yMid = (parentPt[1] - cntrPt[1])/2.0 + cntrPt[1]
createPlot.ax1.text(xMid, yMid, txtString)
def plotTree(myTree, parentPt, nodeTxt):#if the first key tells you what feat was split on
'''
cntrPt用来记录当前要画的树的树根的结点位置
'''
#计算宽高
numLeafs = getNumLeafs(myTree)
depth = getTreeDepth(myTree)
firstStr = list(myTree.keys())[0] #python3与python2的区别,python3要把键变成一个列表
'''
我们希望树根在这些所有叶子节点的中间位置
这里的 1.0 + numLeafs 需要拆开来理解,也就是
plotTree.xOff + float(numLeafs)/2.0/plotTree.totalW +1.0/2.0/plotTree.totalW
plotTree.xOff + 1/2 * float(numLeafs)/plotTree.totalW + 0.5/plotTree.totalW
因为xOff的初始值是-0.5/plotTree.totalW ,是往左偏了0.5/plotTree.tatalW 的,这里正好加回去。
这样cntrPt记录的x坐标正好是所有叶子结点的中心点
'''
cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, plotTree.yOff)
#标记子节点的属性
plotMidText(cntrPt, parentPt, nodeTxt)
plotNode(firstStr, cntrPt, parentPt, decisionNode)
secondDict = myTree[firstStr]
#减少y偏移
plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD
for key in secondDict.keys():
if type(secondDict[key]).__name__=='dict':#test to see if the nodes are dictonaires, if not they are leaf nodes
plotTree(secondDict[key],cntrPt,str(key)) #recursion
else: #it's a leaf node print the leaf node
plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW
plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
#yOff的初始值为1,每向下递归一次,这个值减去 1 / totalD
plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD
def createPlot(inTree):
fig = plt.figure(1, facecolor='white')
fig.clf()
axprops = dict(xticks=[], yticks=[])
createPlot.ax1 = plt.subplot(111, frameon=False, **axprops) #no ticks
#createPlot.ax1 = plt.subplot(111, frameon=False) #ticks for demo puropses
plotTree.totalW = float(getNumLeafs(inTree))
plotTree.totalD = float(getTreeDepth(inTree))
'''
xOff和yOff用来记录当前要画的叶子结点的位置。
画布的范围x轴和y轴都是0到1,我们希望所有的叶子结点平均分布在x轴上。
totalW记录叶子结点的个数,那么 1/totalW 正好是每个叶子结点的宽度。
如果叶子结点的坐标是 1/totalW , 2/totalW, 3/totalW, …, 1 的话,就正好在宽度的最右边,
为了让坐标在宽度的中间,需要减去0.5 / totalW 。初始化 plotTree.xOff 的值为-0.5/plotTree.totalW。
这样每次 xOff + 1/totalW ,正好是下1个结点的准确位置
'''
plotTree.xOff = -0.5/plotTree.totalW
plotTree.yOff = 1.0;
plotTree(inTree, (0.5,1.0), '')
plt.show()
输出效果如下:
。
以上就是关于从原始数据中创建决策树并用python库来绘制树形图。