决策树算法学习笔记(三)-预测隐形眼镜类型

#coding=utf-8
import matplotlib.pyplot as plt
# Box styles for decision and leaf nodes (passed as ``bbox`` by plotNode)
# and the arrow style for tree edges (passed as ``arrowprops``).
decisionNode = {"boxstyle": "sawtooth", "fc": "0.8"}
leafNode = {"boxstyle": "round4", "fc": "0.8"}
arrow_args = {"arrowstyle": "<-"}


def getNumLeafs(myTree):
    """Count the leaf nodes of a nested-dict decision tree.

    The tree has the shape produced by retrieveTree:
    ``{feature: {value: subtree_or_label, ...}}``. Only the first key of
    *myTree* is inspected. Every child that is itself a dict (including dict
    subclasses) is treated as a subtree and counted recursively. Any other
    child counts as one leaf.

    Args:
        myTree: dict whose first key names the splitting feature.

    Returns:
        int: number of leaves under this node.
    """
    # next(iter(...)) replaces ``myTree.keys()[0]``, which fails on Python 3
    # because dict views are not indexable; this form works on 2 and 3.
    firstStr = next(iter(myTree))
    secondDict = myTree[firstStr]
    numLeafs = 0
    for child in secondDict.values():
        if isinstance(child, dict):
            numLeafs += getNumLeafs(child)
        else:
            numLeafs += 1
    return numLeafs

def getTreeDepth(myTree):
    """Return the number of decision levels in a nested-dict decision tree.

    Only the first key of *myTree* is inspected. A leaf child contributes
    depth 1. A dict child (including dict subclasses) contributes 1 plus its
    own depth. The maximum over all children is returned, so an empty value
    dict gives 0.

    Args:
        myTree: dict whose first key names the splitting feature.

    Returns:
        int: depth of the tree below this node.
    """
    # Python 2/3 compatible replacement for ``myTree.keys()[0]``.
    firstStr = next(iter(myTree))
    secondDict = myTree[firstStr]
    maxDepth = 0
    for child in secondDict.values():
        if isinstance(child, dict):
            thisDepth = 1 + getTreeDepth(child)
        else:
            thisDepth = 1
        maxDepth = max(maxDepth, thisDepth)
    return maxDepth

def plotMidText(cntrPt, parentPt, txtString):
    """Write *txtString* at the midpoint between *cntrPt* and *parentPt*.

    Points are (x, y) pairs; the text is placed on ``createPlot.axl``.
    """
    childX, childY = cntrPt[0], cntrPt[1]
    parentX, parentY = parentPt[0], parentPt[1]
    # Midpoint expressed as: child coordinate plus half the gap to the parent.
    midX = (parentX - childX) / 2.0 + childX
    midY = (parentY - childY) / 2.0 + childY
    createPlot.axl.text(midX, midY, txtString)

def plotTree(myTree, parentPt, nodeTxt):
    """Recursively draw *myTree* as a child of the node at *parentPt*.

    Layout state lives in function attributes that createPlot initialises.
    ``plotTree.totalW`` is the leaf count and ``plotTree.totalD`` the depth
    of the whole tree. ``plotTree.xOff`` and ``plotTree.yOff`` are drawing
    cursors: this function advances ``xOff`` once per leaf and restores
    ``yOff`` before returning.

    Args:
        myTree: single-feature dict ``{feature: {value: subtree_or_label}}``.
        parentPt: (x, y) of the parent node in axes-fraction coordinates.
        nodeTxt: label written on the edge from the parent to this node.
    """
    numLeafs = getNumLeafs(myTree)
    # Python 2/3 compatible replacement for ``myTree.keys()[0]``.
    firstStr = next(iter(myTree))
    # Centre this decision node horizontally over the span of its own leaves.
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW,
              plotTree.yOff)
    plotMidText(cntrPt, parentPt, nodeTxt)
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
    # Move one level down for the children...
    plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD
    for key, child in secondDict.items():
        if isinstance(child, dict):
            plotTree(child, cntrPt, str(key))
        else:
            # Each leaf occupies the next horizontal slot.
            plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
            plotNode(child, (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    # ...and move back up so the caller's remaining siblings share its level.
    plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD


def createPlot(inTree):
    """Draw decision tree *inTree* in matplotlib figure 1 and show it.

    Steps:
    - Clear the figure.
    - Create a frameless, tick-free subplot and store it as
      ``createPlot.axl`` (plotNode and plotMidText draw on it).
    - Seed the layout attributes that plotTree reads.
    - Draw from the root at (0.5, 1.0) and call ``plt.show()``.
    """
    figure = plt.figure(1, facecolor='white')
    figure.clf()
    createPlot.axl = plt.subplot(111, frameon=False, xticks=[], yticks=[])
    totalWidth = float(getNumLeafs(inTree))
    totalDepth = float(getTreeDepth(inTree))
    plotTree.totalW = totalWidth
    plotTree.totalD = totalDepth
    # Start half a slot left of 0 so the first leaf lands at 0.5 / totalW.
    plotTree.xOff = -0.5 / totalWidth
    plotTree.yOff = 1.0
    plotTree(inTree, (0.5, 1.0), '')
    plt.show()


def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    """Draw *nodeTxt* in a box at *centerPt*, joined by an arrow to *parentPt*.

    Both points are in axes-fraction coordinates. *nodeType* is the bbox
    style dict (decisionNode or leafNode); the arrow style comes from
    arrow_args. Drawing happens on ``createPlot.axl``.
    """
    coords = 'axes fraction'
    createPlot.axl.annotate(
        nodeTxt,
        xy=parentPt, xycoords=coords,
        xytext=centerPt, textcoords=coords,
        va="center", ha="center",
        bbox=nodeType, arrowprops=arrow_args
    )

# def createPlot():
#     fig=plt.figure(1,facecolor='white')
#     fig.clf()
#     createPlot.axl=plt.subplot(111,frameon=False)
#     plotNode('决策节点',(0.5,0.1),(0.1,0.5),decisionNode)
#     plotNode('叶节点', (0.8, 0.1), (0.3, 0.8), leafNode)
#     plt.show()

def retrieveTree(i):
    """Return canned sample tree number *i* from a list of two trees.

    Both trees are rebuilt on every call, so callers may mutate the result
    without affecting later calls. Indexing follows normal list rules
    (negative indices allowed, IndexError when out of range).
    """
    flippersTree = {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
    headTree = {'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}}
    samples = [flippersTree, headTree]
    return samples[i]
# myTree=retrieveTree(0)
# myTree['no surfacing'][3]='maybe'
# print myTree
# createPlot(myTree)
# print retrieveTree(0)
# print  getNumLeafs(myTree=retrieveTree(0))
# print getTreeDepth(myTree=retrieveTree(0))

def fileReading(filename):
    """Load the tab-separated lenses data set from *filename*.

    Each non-blank line becomes one row: surrounding whitespace (including
    the newline) is stripped and the rest is split on tab characters.

    Args:
        filename: path to the data file, e.g. 'lenses.txt'.

    Returns:
        tuple: ``(lenses, lensesLabels)``.
        - lenses: list of rows, each a list of str.
        - lensesLabels: the fixed feature names
          ['age', 'prescript', 'astigmatic', 'tearRate'].
    """
    # ``with`` guarantees the file handle is closed, even if reading fails.
    with open(filename) as fr:
        # Skip blank lines (e.g. trailing empty lines) so they do not become
        # bogus [''] rows in the data set.
        lenses = [line.strip().split('\t') for line in fr if line.strip()]
    lensesLabels = ['age', 'prescript', 'astigmatic', 'tearRate']
    return lenses, lensesLabels

if __name__ == '__main__':
    # Demo: load the lenses data, build its decision tree, print and plot it.
    dataSet, labels = fileReading('lenses.txt')
    # NOTE(review): createTree is not defined in this file; presumably it is
    # the tree builder from the earlier decision-tree notes and must be
    # imported here -- confirm the module it comes from.
    lensesTree = createTree(dataSet, labels)
    print(lensesTree)
    # createPlot is defined in this file; no ``treePlotter`` name is in scope.
    createPlot(lensesTree)



评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值