机器学习实战—决策树(二)

#-*-coding:utf-8-*-
import ch
ch.set_ch()
import matplotlib.pyplot as plt

decisionNode = dict(boxstyle = "sawtooth",fc="0.8")
leafNode = dict(boxstyle="round4",fc = "0.8")
arrow_args = dict(arrowstyle = "<-")

#建立标注annotate
def plotNode(nodeTxt,centerPt,parentPt,nodeType):
                            # 标注内容   标注位置                                    标签位置
    createPlot.ax1.annotate(nodeTxt,xy=parentPt,xycoords = 'axes fraction',xytext = centerPt,textcoords='axes fraction',va="center",\
                                         #标签的格式         箭头的格式
                            ha = "center",bbox = nodeType,arrowprops=dict(arrowstyle="<-"))

def createPlotTemp():
                    #图名,可以是数字  背景颜色              
    fig = plt.figure("xihuan",facecolor = 'white')
    fig.clf()#clear the figure
    createPlotTemp.ax1 = plt.subplot(111,frameon = False)#产生一个子图,不显示坐标轴,但有坐标
    plotNode(U'决策节点',(0.5,0.1),(0.1,0.5),decisionNode)
    plotNode(U'叶节点',  (0.8,0.1),(0.3,0.8),leafNode)
    plt.show()

#计算决策树的叶子节点的数目
def getNumLeafs(myTree):
    numLeafs = 0
    firstStr = myTree.keys()[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key])==dict:
            numLeafs += getNumLeafs(secondDict[key])
        else: numLeafs += 1
    return numLeafs
#计算树的深度
def getTreeDepth(myTree):
    maxDepth = 0
    firstStr = myTree.keys()[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key])==dict:
            thisDepth = 1 + getTreeDepth(secondDict[key])
        else: thisDepth = 1;
        if thisDepth > maxDepth:maxDepth = thisDepth
    return maxDepth
#生成一棵决策树
def retrieveTree(i):
    listOfTrees = [{'no surfacing': {0:'no',1:{'flippers':\
                   {0: 'no', 1: 'yes'}}}},
                   {'no surfacing': {0:'no',1:{'flippers':\
                   {0:{'head':{0: 'no', 1: 'yes'}}, 1: 'no'}}}}
                   ]
    return listOfTrees[i]

def plotMidText(centrPt,parentPt,txtString):
    xMid = (parentPt[0]-centrPt[0])/2.0+centrPt[0]
    yMid = (parentPt[1]-centrPt[1])/2.0+centrPt[1]
    createPlot.ax1.text(xMid,yMid,txtString)


def createPlot(inTree):
    fig = plt.figure("xihuan",facecolor = 'white')
    fig.clf()
    axprops = dict(xticks=[],yticks=[])
    createPlot.ax1 = plt.subplot(111,frameon = False,**axprops)
    plotTree.totalW = float(getNumLeafs(inTree))
    plotTree.totalD = float(getTreeDepth(inTree))
    plotTree.xoff = -0.5/plotTree.totalW;plotTree.yoff = 1.0;#为了保证根结点标注于标签位置一致
    plotTree(inTree,(0.5,1.0),'')
    plt.show()
    
def plotTree(myTree,parentPt,nodeText):
    numleafs = getNumLeafs(myTree)
    depth = getTreeDepth(myTree)
    firstStr = myTree.keys()[0] #第一个分类特征
    centrPt = (plotTree.xoff + (1.0+float(numleafs))/2.0/plotTree.totalW,\
               plotTree.yoff)
    plotMidText(centrPt,parentPt,nodeText)#显示文本标签信息,根节点为空
    plotNode(firstStr,centrPt,parentPt,decisionNode)#打印标注特征信息
    secondDict = myTree[firstStr]
    plotTree.yoff = plotTree.yoff-1.0/plotTree.totalD#调整下一个子数的Y方向位置
    for key in secondDict.keys():
        if type(secondDict[key])==dict:
            plotTree(secondDict[key],centrPt,str(key))
        else:#画出结点即可
            plotTree.xoff = plotTree.xoff + 1.0/plotTree.totalW
            plotNode(secondDict[key],(plotTree.xoff,plotTree.yoff),centrPt,leafNode)
            plotMidText((plotTree.xoff,plotTree.yoff),centrPt,str(key))
    plotTree.yoff = plotTree.yoff+1.0/plotTree.totalD#由于递归返回上一层,所以这里返回上层的y分量高度
            
    
    
mytree = retrieveTree(1)

#print getTreeDepth(mytree)
createPlot(mytree)

  • 1
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值