python中如何画出决策树_python 实现决策树画图-Go语言中文社区

文章参考自:http://blog.csdn.net/wzmsltw/article/details/51039928

# -*- coding: utf-8 -*-

"""

Created on Wed Sep 27 14:52:51 2017

@author: Administrator

"""

import matplotlib.pyplot as plt

decisionNode=dict(boxstyle="sawtooth",fc="0.8")

leafNode=dict(boxstyle="round4",fc="0.8")

arrow_args=dict(arrowstyle="

#计算树的叶子节点数量

def getNumLeafs(myTree):

numLeafs=0

firstStr=myTree.keys()[0]

secondDict=myTree[firstStr]

for key in secondDict.keys():

if type(secondDict[key]).__name__=='dict':

numLeafs+=getNumLeafs(secondDict[key])

else: numLeafs+=1

return numLeafs

#计算树的最大深度

def getTreeDepth(myTree):

maxDepth=0

firstStr=myTree.keys()[0]

secondDict=myTree[firstStr]

for key in secondDict.keys():

if type(secondDict[key]).__name__=='dict':

thisDepth=1+getTreeDepth(secondDict[key])

else: thisDepth=1

if thisDepth>maxDepth:

maxDepth=thisDepth

return maxDepth

#画节点

def plotNode(nodeTxt,centerPt,parentPt,nodeType):

createPlot.ax1.annotate(nodeTxt,xy=parentPt,xycoords='axes fraction',xytext=centerPt,textcoords='axes fraction',va="center", ha="center",bbox=nodeType,arrowprops=arrow_args)

#画箭头上的文字

def plotMidText(cntrPt,parentPt,txtString):

lens=len(txtString)

xMid=(parentPt[0]+cntrPt[0])/2.0-lens*0.002

yMid=(parentPt[1]+cntrPt[1])/2.0

createPlot.ax1.text(xMid,yMid,txtString)

def plotTree(myTree,parentPt,nodeTxt):

numLeafs=getNumLeafs(myTree)

depth=getTreeDepth(myTree)

firstStr=myTree.keys()[0]

cntrPt=(plotTree.x0ff+(1.0+float(numLeafs))/2.0/plotTree.totalW,plotTree.y0ff)

plotMidText(cntrPt,parentPt,nodeTxt)

plotNode(firstStr,cntrPt,parentPt,decisionNode)

secondDict=myTree[firstStr]

plotTree.y0ff=plotTree.y0ff-1.0/plotTree.totalD

for key in secondDict.keys():

if type(secondDict[key]).__name__=='dict':

plotTree(secondDict[key],cntrPt,str(key))

else:

plotTree.x0ff=plotTree.x0ff+1.0/plotTree.totalW

plotNode(secondDict[key],(plotTree.x0ff,plotTree.y0ff),cntrPt,leafNode)

plotMidText((plotTree.x0ff,plotTree.y0ff),cntrPt,str(key))

plotTree.y0ff=plotTree.y0ff+1.0/plotTree.totalD

def createPlot(inTree):

fig=plt.figure(1,facecolor='white')

fig.clf()

axprops=dict(xticks=[],yticks=[])

createPlot.ax1=plt.subplot(111,frameon=False,**axprops)

plotTree.totalW=float(getNumLeafs(inTree))

plotTree.totalD=float(getTreeDepth(inTree))

plotTree.x0ff=-0.5/plotTree.totalW

plotTree.y0ff=1.0

plotTree(inTree,(0.5,1.0),'')

plt.show()

###############测试代码

tree={'navel': {'even': 0L,

'little_sinking': {'root': {'curl_up': 0L,

'little_curl_up': {'color': {'black': {'texture': {'blur': 1L,

'distinct': 0L,

'little_blur': 1L}},

'dark_green': 1L,

'light_white': 1L}},

'stiff': 1L}},

'sinking': {'color': {'black': 1L, 'dark_green': 1L, 'light_white': 0L}}}}

createPlot(tree)

结果:

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值