Use matplotlib to draw the tree

Tags: DecisionTree

Plotting the tree in Python with Matplotlib annotations

Unfortunately, Python doesn’t include a good tool for plotting trees, so we’ll make our own.

That is the true engineering spirit.

Matplotlib has a great tool called annotations that can add text near data in a plot.

1. Plotting tree nodes with text annotations

Use the text annotation feature to draw the tree nodes.

import matplotlib.pyplot as plt

# define the node types: decision node, leaf node, and the arrow style

decisionNode=dict(boxstyle="sawtooth",fc="0.8")
leafNode=dict(boxstyle="round4",fc="0.8")
arrow_args=dict(arrowstyle="<-")

# define the function that draws one node

def plotNode(nodeText,centerPt,parentPt,nodeType):
    createPlot.ax1.annotate(nodeText,xy=parentPt,xycoords='axes fraction',xytext=centerPt,textcoords='axes fraction',
                           va='center',ha='center',bbox=nodeType,arrowprops=arrow_args)
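    # The number of parameters here is a bit scary. In short: xy is the point the arrow
    # is attached to (the parent), xytext is where the node text is placed (the child),
    # and both are given in axes-fraction coordinates.

# create the plot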

def createPlot():
    fig=plt.figure(1,facecolor='white')
    fig.clf()                 # clear figure 1 so each call starts from a blank canvas
    createPlot.ax1=plt.subplot(111,frameon=False)
    plotNode('a decision node ',(0.5,0.1),(0.1,0.5),decisionNode)
    plotNode('a leaf node',(0.8,0.1),(0.3,0.8),leafNode)
    plt.show()
# call the function to show the plot on screen

createPlot()

[Figure: output of createPlot()]
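To make the long annotate call less intimidating, here is a minimal sketch of a single annotation, separate from plotNode. The Agg backend (chosen so the sketch runs without a display), the label text, and the coordinates are all assumptions for illustration. The point to notice is that xy is where the arrow is attached, while xytext is where the text box is centered.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend, so no window is needed
import matplotlib.pyplot as plt

fig, ax = plt.subplots()
ax.set_axis_off()

# One annotation: the arrow connects the text box at xytext with the point at xy.
ann = ax.annotate(
    "child node",                          # text shown inside the box
    xy=(0.5, 0.9),                         # point the arrow is attached to (the parent)
    xycoords="axes fraction",              # xy is given as a fraction of the axes
    xytext=(0.3, 0.4),                     # center of the text box (the child)
    textcoords="axes fraction",            # xytext uses the same coordinate system
    va="center", ha="center",              # center the text on xytext
    bbox=dict(boxstyle="round4", fc="0.8"),  # box drawn around the text
    arrowprops=dict(arrowstyle="<-"),      # "<-" puts the arrow head at the text end
)

fig.canvas.draw()  # force a render so any drawing error would surface here
```

With the "<-" style the head should end up at the text box, which is why the tree drawings show arrows running from parent to child.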

2. A strategy for plotting the tree

Identifying the number of leaves in a tree and the depth

We need to know how many leaf nodes and how many levels the tree has in order to size the plot properly in the X and Y directions.

# getNumLeafs: count the leaf nodes of a tree
def getNumLeafs(myTree):
    numLeafs=0
    firstList=list(myTree.keys())
    firstStr=firstList[0]
    secondDict=myTree[firstStr]# read the value stored under the first key
    for key in secondDict.keys():# check whether the child is itself a dictionary
        if type(secondDict[key]).__name__=='dict':
            numLeafs+=getNumLeafs(secondDict[key])
        else: numLeafs+=1
    return numLeafs

# getTreeDepth: count the levels of a tree
def getTreeDepth(myTree):
    maxDepth=0
    firstList=list(myTree.keys())
    firstStr=firstList[0]
    secondDict=myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__=='dict':
            thisDepth=1+getTreeDepth(secondDict[key])
        else: thisDepth=1
        if thisDepth>maxDepth:
            maxDepth=thisDepth
    return maxDepth
Note that there is a Python version difference here:
Python 2: firstStr=myTree.keys()[0]
Python 3: firstList=list(myTree.keys())
          firstStr=firstList[0]
The purpose of these lines is to read the first key of the dictionary.
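The version difference can be checked directly. This small sketch, assuming Python 3, shows that indexing dict.keys() fails, that the list(...) workaround gives the first key, and that next(iter(...)) is an equivalent alternative (my suggestion, not used elsewhere in this post). The variable names are illustrative.

```python
tree = {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}

# Python 3: dict.keys() returns a view object, which cannot be indexed.
try:
    tree.keys()[0]
    subscript_failed = False
except TypeError:
    subscript_failed = True

# Workaround used in this post: copy the keys into a list first.
first_via_list = list(tree.keys())[0]

# Alternative: take the first item from an iterator over the keys.
first_via_iter = next(iter(tree))

print(subscript_failed, first_via_list, first_via_iter)
```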
# build some sample tree data
def retrieveTree(i):
    listOfTrees =[{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},
                  {'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}}
                  ]
    return listOfTrees[i]
retrieveTree(0)
{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
# run the function to check the result
getNumLeafs(retrieveTree(1))
4
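The same check can be run for the depth as well. The sketch below is a compact restatement of the two counting functions under hypothetical names (num_leaves, tree_depth), using isinstance instead of comparing type names. It is not a replacement for the functions above, only a way to confirm the expected numbers for both sample trees.

```python
def num_leaves(tree):
    """Count the leaf values in a nested-dict tree (hypothetical helper name)."""
    children = next(iter(tree.values()))  # branches under the single root key
    return sum(num_leaves(c) if isinstance(c, dict) else 1
               for c in children.values())

def tree_depth(tree):
    """Count the decision levels on the longest root-to-leaf path."""
    children = next(iter(tree.values()))
    return max(1 + tree_depth(c) if isinstance(c, dict) else 1
               for c in children.values())

trees = [
    {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},
    {'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}},
]

for t in trees:
    print(num_leaves(t), tree_depth(t))
# prints:
# 3 2
# 4 3
```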
# Plots text between child and parent
def plotMidText(cntrPt,parentPt,txtString):
    xMid=(parentPt[0]-cntrPt[0])/2.0+cntrPt[0]
    yMid=(parentPt[1]-cntrPt[1])/2.0+cntrPt[1]
    createPlot.ax1.text(xMid,yMid,txtString)

# define the main functions, plotTree
def plotTree(myTree, parentPt, nodeTxt):# the first key tells you which feature was split on
    numLeafs = getNumLeafs(myTree)  #this determines the x width of this tree
    depth = getTreeDepth(myTree)
    firstList = list(myTree.keys())
    firstStr=firstList[0] #the text label for this node should be this
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, plotTree.yOff)
    plotMidText(cntrPt, parentPt, nodeTxt)
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
    plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD
    for key in secondDict.keys():
        if type(secondDict[key]).__name__=='dict':# test whether the node is a dictionary; if not, it is a leaf node
            plotTree(secondDict[key],cntrPt,str(key))        #recursion
        else:   # it's a leaf node, so plot the leaf node
            plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD
# if you do get a dictionary, you know it's a tree, and the first element will be another dict
# display the plot
def createPlot(inTree):
    fig=plt.figure(1,facecolor='white')
    fig.clf()
    axprops=dict(xticks=[],yticks=[])
    createPlot.ax1=plt.subplot(111,frameon=False,**axprops)
    plotTree.totalW=float(getNumLeafs(inTree))
    plotTree.totalD=float(getTreeDepth(inTree))
    plotTree.xOff=-0.5/plotTree.totalW
    plotTree.yOff=1.0
    plotTree(inTree,(0.5,1.0),'')
    plt.show()
# call the function to draw the tree
createPlot(retrieveTree(0))

[Figure: plotTree output for retrieveTree(0)]

createPlot(retrieveTree(1))

[Figure: plotTree output for retrieveTree(1)]
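The xOff/yOff bookkeeping in plotTree is easier to follow when the coordinates are computed without drawing anything. The sketch below mirrors that arithmetic under hypothetical names (count_leaves, depth_of, layout) and returns (label, x, y) tuples instead of plotting. It is a reading aid written for this post, not part of the original code.

```python
def count_leaves(tree):
    children = next(iter(tree.values()))
    return sum(count_leaves(c) if isinstance(c, dict) else 1
               for c in children.values())

def depth_of(tree):
    children = next(iter(tree.values()))
    return max(1 + depth_of(c) if isinstance(c, dict) else 1
               for c in children.values())

def layout(tree):
    """Return (label, x, y) for every node, using the same offsets as plotTree."""
    total_w = float(count_leaves(tree))   # one horizontal slot per leaf
    total_d = float(depth_of(tree))       # one vertical step per level
    state = {"x_off": -0.5 / total_w}     # start half a slot left of the axes
    placed = []

    def walk(node, y):
        label = next(iter(node))
        # center the decision node above the leaves it covers
        x = state["x_off"] + (1.0 + count_leaves(node)) / 2.0 / total_w
        placed.append((label, x, y))
        child_y = y - 1.0 / total_d       # children sit one level lower
        for child in node[label].values():
            if isinstance(child, dict):
                walk(child, child_y)
            else:
                state["x_off"] += 1.0 / total_w   # advance to the next leaf slot
                placed.append((child, state["x_off"], child_y))

    walk(tree, 1.0)
    return placed

tree = {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
positions = layout(tree)
for label, x, y in positions:
    print(f"{label}: x={x:.3f} y={y:.3f}")
```

For this tree the three leaves land at x = 0.167, 0.500 and 0.833, that is, at the centers of three equal slots, and each decision node sits midway over the leaves below it.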

3. Put our decision tree code to use on some real data

# classification function for an existing decision tree

def classify(inputTree,featLabels,testVec):
    firstList=list(inputTree.keys())
    firstStr=firstList[0]
    secondDict=inputTree[firstStr]
    featIndex=featLabels.index(firstStr)
    for key in secondDict.keys():
        if testVec[featIndex]==key:
            if type(secondDict[key]).__name__=='dict':
                classLabel=classify(secondDict[key],featLabels,testVec)
            else:
                classLabel=secondDict[key]
    return classLabel
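A short usage sketch may help here. It uses a variant under a hypothetical name, classify_lookup, that indexes the branch directly instead of looping over the keys. That is my own simplification, and it raises KeyError for a feature value the tree has never seen. The feature labels and test vectors are assumptions that match the first sample tree.

```python
def classify_lookup(input_tree, feat_labels, test_vec):
    """Walk a nested-dict decision tree and return the class label."""
    first_str = next(iter(input_tree))             # feature tested at this node
    second_dict = input_tree[first_str]            # branches keyed by feature value
    feat_index = feat_labels.index(first_str)      # position of that feature in test_vec
    branch = second_dict[test_vec[feat_index]]     # KeyError if the value is unseen
    if isinstance(branch, dict):
        return classify_lookup(branch, feat_labels, test_vec)  # keep descending
    return branch                                  # reached a leaf: the class label

my_tree = {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
labels = ['no surfacing', 'flippers']

print(classify_lookup(my_tree, labels, [1, 0]))  # no
print(classify_lookup(my_tree, labels, [1, 1]))  # yes
```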
Use pickle to serialize objects so that they can be stored for later use.

4. Persisting the decision tree

# methods for persisting the decision tree with pickle
import pickle

def storeTree(inputTree,filename):
    # pickle writes bytes, so the file must be opened in binary mode in Python 3
    with open(filename,'wb') as fw:
        pickle.dump(inputTree,fw)

def grabTree(filename):
    # read in binary mode to match storeTree
    with open(filename,'rb') as fr:
        return pickle.load(fr)
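A quick round trip confirms that a stored tree comes back unchanged. This sketch assumes Python 3, where pickle files are opened in binary mode. The helper names store_tree and grab_tree and the file name classifierStorage.pkl are hypothetical, and the file is written to a temporary directory so nothing is left behind.

```python
import os
import pickle
import tempfile

def store_tree(tree, filename):
    # binary mode: pickle.dump writes bytes
    with open(filename, "wb") as fw:
        pickle.dump(tree, fw)

def grab_tree(filename):
    # binary mode: pickle.load reads bytes
    with open(filename, "rb") as fr:
        return pickle.load(fr)

tree = {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "classifierStorage.pkl")  # hypothetical file name
    store_tree(tree, path)
    restored = grab_tree(path)

print(restored == tree)  # True
```

Only unpickle files you created yourself, since loading a pickle can execute arbitrary code.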

Summary

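The most important thing is still to understand how the C4.5 and CART algorithms work; see Zhou Zhihua's "watermelon book" for details. Pruning and the handling of continuous and missing values also matter.
Writing the code is just the process of turning theory into practice; where working code already exists, I think it can simply be imported.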