决策树预测学生成绩模型
前言
决策树主要用于分类问题。导入学生成绩,用0-1预测学生是否通过考试;
主要是放下 treePlotter 模块的代码;
一、决策树预测代码
之前的怎么又失效了。。。提取码: o54j
二、treePlotter模块
自定义 treePlotter 模块:新建名为 treePlotter 的文件夹,在其中新建 Python 文件 __init__.py 并复制下面的代码,
然后把该文件夹放在父文件夹下:
# _*_ coding: UTF-8 _*_
import matplotlib.pyplot as plt
"""绘决策树的函数"""
# Box style used when annotating a decision (branch) node.
decisionNode = {"boxstyle": "sawtooth", "fc": "0.8"}
# Box style used when annotating a leaf node.
leafNode = {"boxstyle": "round4", "fc": "0.8"}
# Arrow style passed as ``arrowprops`` when annotating a node.
arrow_args = {"arrowstyle": "<-"}
# Count the leaf nodes of the tree.
def getNumLeafs(myTree):
    """Return the number of leaf nodes in a nested-dict decision tree.

    The tree is read as ``{root: {branch: subtree_or_leaf, ...}}``: only the
    first key of ``myTree`` is used as the root. Every value under it that is
    a ``dict`` is recursed into; any other value counts as one leaf.

    Args:
        myTree: non-empty dict whose first value is itself a dict of branches.

    Returns:
        The total number of non-dict values reachable from the root.

    Raises:
        IndexError: if ``myTree`` (or a nested subtree) is empty.
    """
    firstStr = list(myTree)[0]
    # isinstance() instead of comparing the type's name, so that dict
    # subclasses (e.g. OrderedDict) are walked as subtrees, not counted
    # as a single leaf.
    return sum(
        getNumLeafs(child) if isinstance(child, dict) else 1
        for child in myTree[firstStr].values()
    )
# Compute the maximum depth of the tree.
def getTreeDepth(myTree):
    """Return the maximum depth of a nested-dict decision tree.

    The tree is read as ``{root: {branch: subtree_or_leaf, ...}}``: only the
    first key of ``myTree`` is used as the root. A branch that ends in a
    non-dict value contributes depth 1; a branch that holds a ``dict``
    contributes 1 plus the depth of that subtree.

    Args:
        myTree: non-empty dict whose first value is itself a dict of branches.

    Returns:
        The largest branch depth, or 0 when the root has no branches.

    Raises:
        IndexError: if ``myTree`` (or a nested subtree) is empty.
    """
    firstStr = list(myTree)[0]
    # isinstance() instead of comparing the type's name, so that dict
    # subclasses (e.g. OrderedDict) are measured as subtrees.
    return max(
        (
            1 + getTreeDepth(child) if isinstance(child, dict) else 1
            for child in myTree[firstStr].values()
        ),
        default=0,
    )
# Draw one node.
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    """Annotate the shared axes ``createPlot.ax1`` with one boxed node.

    Args:
        nodeTxt: text shown inside the node's box.
        centerPt: (x, y) position of the text, in axes-fraction coordinates.
        parentPt: (x, y) point being annotated, in axes-fraction coordinates;
            the arrow styled by ``arrow_args`` links it with ``centerPt``.
        nodeType: dict of box properties passed as ``bbox``.
    """
    placement = dict(
        xycoords='axes fraction',
        textcoords='axes fraction',
        va="center",
        ha="center",
    )
    createPlot.ax1.annotate(
        nodeTxt,
        xy=parentPt,
        xytext=centerPt,
        bbox=nodeType,
        arrowprops=arrow_args,
        **placement,
    )
# Label the arrow between two nodes.
def plotMidText(cntrPt, parentPt, txtString):
    """Write ``txtString`` near the midpoint between two points.

    Args:
        cntrPt: (x, y) of the child node.
        parentPt: (x, y) of the parent node.
        txtString: label to draw on ``createPlot.ax1``.
    """
    # Shift the label left by 0.002 per character of the text.
    shift = len(txtString) * 0.002
    mid_x = (parentPt[0] + cntrPt[0]) / 2.0 - shift
    mid_y = (parentPt[1] + cntrPt[1]) / 2.0
    createPlot.ax1.text(mid_x, mid_y, txtString)
def plotTree(myTree, parentPt, nodeTxt):
    """Recursively draw ``myTree`` below ``parentPt``.

    Layout state is kept on function attributes that the caller must set
    before the first call: ``plotTree.totalW`` and ``plotTree.totalD`` are
    the divisors for the horizontal and vertical step sizes, and
    ``plotTree.x0ff`` / ``plotTree.y0ff`` hold the current drawing cursor.

    Args:
        myTree: nested dict ``{root: {branch: subtree_or_leaf, ...}}``; only
            the first key is used as the root label.
        parentPt: (x, y) of the parent node the root is linked to.
        nodeTxt: label drawn between ``parentPt`` and the root.
    """
    numLeafs = getNumLeafs(myTree)
    firstStr = list(myTree)[0]
    # Place the root horizontally from the cursor using this subtree's
    # leaf count, at the current vertical level.
    cntrPt = (
        plotTree.x0ff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW,
        plotTree.y0ff,
    )
    plotMidText(cntrPt, parentPt, nodeTxt)
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    # Children are drawn one level lower.
    plotTree.y0ff -= 1.0 / plotTree.totalD
    for key, child in myTree[firstStr].items():
        # isinstance() instead of comparing the type's name, so dict
        # subclasses are drawn as subtrees too.
        if isinstance(child, dict):
            plotTree(child, cntrPt, str(key))
        else:
            # Each leaf advances the horizontal cursor by one slot.
            plotTree.x0ff += 1.0 / plotTree.totalW
            leafPt = (plotTree.x0ff, plotTree.y0ff)
            plotNode(child, leafPt, cntrPt, leafNode)
            plotMidText(leafPt, cntrPt, str(key))
    # Restore the vertical cursor for the caller's remaining siblings.
    plotTree.y0ff += 1.0 / plotTree.totalD
def createPlot(inTree):
    """Draw ``inTree`` in figure 1 and show it.

    Clears figure 1, creates a frameless, tick-free subplot stored on
    ``createPlot.ax1``, initialises the layout attributes read by
    ``plotTree`` from the tree's leaf count and depth, draws the tree
    starting from (0.5, 1.0), and calls ``plt.show()``.

    Args:
        inTree: nested dict ``{root: {branch: subtree_or_leaf, ...}}``.
    """
    plt.figure(1, facecolor='white').clf()
    createPlot.ax1 = plt.subplot(111, frameon=False, xticks=[], yticks=[])
    # Layout divisors: total leaf count (width) and total depth (height).
    plotTree.totalW = float(getNumLeafs(inTree))
    plotTree.totalD = float(getTreeDepth(inTree))
    # Start the horizontal cursor half a leaf slot left of the axes.
    plotTree.x0ff = -0.5 / plotTree.totalW
    plotTree.y0ff = 1.0
    root_anchor = (0.5, 1.0)
    plotTree(inTree, root_anchor, '')
    plt.show()
if __name__ == '__main__':
    # createPlot() requires a tree argument; calling it with none raised
    # TypeError. Draw a small illustrative tree in the nested-dict format
    # {root: {branch: subtree_or_leaf}} instead.
    demoTree = {'attendance': {0: 'fail', 1: {'homework': {0: 'fail', 1: 'pass'}}}}
    createPlot(demoTree)