Machine Learning in Action (Python): Decision Trees and Plotting Trees with Matplotlib Annotations

The code in this chapter is fairly hard to follow, mainly because the Matplotlib calls take many parameters and can be used in very flexible ways, which tends to confuse beginners.

import matplotlib.pyplot as plt

# boxstyle="sawtooth" gives the annotation box a sawtooth (wavy) border;
# fc sets the fill shade inside the box
decisionNode = dict(boxstyle="sawtooth", fc="0.8")
leafNode = dict(boxstyle="round4", fc="0.8")
arrow_args = dict(arrowstyle="<-")   # arrow style

def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    # xy is the start point of the arrow (the parent node),
    # xytext is where the annotation box itself is drawn
    createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',
                            xytext=centerPt, textcoords='axes fraction',
                            va="center", ha="center", bbox=nodeType, arrowprops=arrow_args)
def createPlot(inTree):
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    axprops = dict(xticks=[], yticks=[])
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)    # no ticks
    #createPlot.ax1 = plt.subplot(111, frameon=False) # ticks for demo purposes
    plotTree.totalW = float(getNumLeafs(inTree))
    plotTree.totalD = float(getTreeDepth(inTree))
    plotTree.xOff = -0.5/plotTree.totalW; plotTree.yOff = 1.0
    plotTree(inTree, (0.5, 1.0), '')
    plt.show()
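Once all of the plotting functions in this post are collected in treePlotter.py, drawing a tree is a single call; a quick sketch (the tree dictionary here is the sample tree used later in the post):

>>> import treePlotter
>>> myTree = {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
>>> treePlotter.createPlot(myTree)

The resulting figure is shown below.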
<span style="font-size:18px">
</span>
<span style="font-size:18px"><img src="https://img-blog.csdn.net/20151127122340740?watermark/2/text/aHR0cDovL2Jsb2cuY3Nkbi5uZXQv/font/5a6L5L2T/fontsize/400/fill/I0JBQkFCMA==/dissolve/70/gravity/Center" alt="" /></span>
<span style="font-size:18px">。</span>

Constructing the Annotation Tree

<span style="font-size:18px;">首先我们要知道树的高度和宽度,就是叶子和深度。以便我们好确定注解框的位置。</span>
def getNumLeafs(myTree):
    numLeafs = 0
    firstStr = myTree.keys()[0]   # key of the root node (Python 2; in Python 3 use list(myTree.keys())[0])
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':   # a dictionary is an internal node; anything else is a leaf
            numLeafs += getNumLeafs(secondDict[key])
        else:   numLeafs += 1
    return numLeafs

def getTreeDepth(myTree):
    maxDepth = 0
    firstStr = myTree.keys()[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':   # a dictionary is an internal node; anything else is a leaf
            thisDepth = 1 + getTreeDepth(secondDict[key])   # recurse into the subtree
        else:   thisDepth = 1
        if thisDepth > maxDepth: maxDepth = thisDepth   # keep the deepest branch seen so far
    return maxDepth
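A quick sanity check of the two functions (a sketch I worked through by hand, using the sample tree that the later session gets from retrieveTree(0)):

>>> myTree = {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
>>> getNumLeafs(myTree)
3
>>> getTreeDepth(myTree)
2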
<span style="font-size:18px;">通过上面的连个函数,我们就可以获得树的大小了。下面我们要绘制前面的那张图,接下来的代码就不那么简单了。</span>
def plotMidText(cntrPt, parentPt, txtString):
    # write the branch condition (the 0 or 1 seen earlier) at the midpoint of the edge
    xMid = (parentPt[0]-cntrPt[0])/2.0 + cntrPt[0]
    yMid = (parentPt[1]-cntrPt[1])/2.0 + cntrPt[1]
    createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center", rotation=30)

def plotTree(myTree, parentPt, nodeTxt):
    # the first key tells you what feature was split on; the function recurses
    # level by level until it reaches the leaf nodes and plots them
    numLeafs = getNumLeafs(myTree)  # this determines the x width of this subtree
    depth = getTreeDepth(myTree)
    firstStr = myTree.keys()[0]     # the text label for this node
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, plotTree.yOff)
    plotMidText(cntrPt, parentPt, nodeTxt)
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
    plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':   # a dictionary means another decision node
            plotTree(secondDict[key], cntrPt, str(key))        # recursion
        else:   # it's a leaf node: plot it directly
            plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD
<span style="font-size:18px;">
上面代码理解起来比较难,我的方法就是,因为这个树比较小,我们可以跟着函数的流程走一遍,自己画一下,大概就能搞懂了,最后才发现,设置初始值真是很有技巧的才行
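For reference, this is my own hand trace of plotTree on the sample tree (not from the book; coordinates are axes fractions, rounded). With totalW = 3 leaves and totalD = 2 levels, plotTree.xOff starts at -0.5/3 ≈ -0.167 and plotTree.yOff at 1.0:

root 'no surfacing'   -> cntrPt = (-0.167 + (1+3)/2.0/3, 1.0) = (0.5, 1.0)
leaf 'no'             -> (0.167, 0.5)
decision 'flippers'   -> (0.667, 0.5)
leaf 'no'             -> (0.5, 0.0)
leaf 'yes'            -> (0.833, 0.0)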


Testing and Storing the Classifier

Given a set of feature values, we check whether this classifier gives us the result we expect.
def classify(inputTree, featLabels, testVec):
    firstStr = inputTree.keys()[0]            # feature tested at this node
    secondDict = inputTree[firstStr]
    featIndex = featLabels.index(firstStr)    # position of that feature in the test vector
    key = testVec[featIndex]
    valueOfFeat = secondDict[key]
    if isinstance(valueOfFeat, dict):         # still a decision node: keep descending
        classLabel = classify(valueOfFeat, featLabels, testVec)
    else: classLabel = valueOfFeat            # leaf node: this is the class label
    return classLabel
<span style="font-size:18px;">这也是一个递归函数,一步步找到最低端的子结点。</span>
>>> import trees
>>> import treePlotter
>>> myDat,labels = trees.createDataSet()
>>> labels
['no surfacing', 'flippers']
>>> myTree = treePlotter.retrieveTree(0)
>>> myTree
{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
>>> trees.classify(myTree,labels,[1,0])
'no'
>>> trees.classify(myTree,labels,[1,1])
'yes'
<span style="font-size:18px;">我们知道建立一个决策树是很好时间的,我们在一个大的数据集下建立的决策树费时很久,我们希望下次可以直接用而不是再生成一遍,所以我们要存储已建好的决策树。</span>
def storeTree(inputTree, filename):
    import pickle
    fw = open(filename, 'w')     # Python 2 text mode; in Python 3 open with 'wb'
    pickle.dump(inputTree, fw)
    fw.close()

def grabTree(filename):
    import pickle
    fr = open(filename)          # in Python 3 open with 'rb'
    return pickle.load(fr)
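A short usage sketch (the file name here is just an example; myTree is the sample tree from the session above):

>>> trees.storeTree(myTree, 'classifierStorage.txt')
>>> trees.grabTree('classifierStorage.txt')
{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}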

Example: predicting contact lens types

>>> fr = open('lenses.txt')
>>> lenses = [inst.strip().split('\t') for inst in fr.readlines()]
>>> lenses
[['young', 'myope', 'no', 'reduced', 'no lenses'], ['young', 'myope', 'no', 'normal', 'soft'], ['young', 'myope', 'yes', 'reduced', 'no lenses'], ['young', 'myope', 'yes', 'normal', 'hard'], ['young', 'hyper', 'no', 'reduced', 'no lenses'], ['young', 'hyper', 'no', 'normal', 'soft'], ['young', 'hyper', 'yes', 'reduced', 'no lenses'], ['young', 'hyper', 'yes', 'normal', 'hard'], ['pre', 'myope', 'no', 'reduced', 'no lenses'], ['pre', 'myope', 'no', 'normal', 'soft'], ['pre', 'myope', 'yes', 'reduced', 'no lenses'], ['pre', 'myope', 'yes', 'normal', 'hard'], ['pre', 'hyper', 'no', 'reduced', 'no lenses'], ['pre', 'hyper', 'no', 'normal', 'soft'], ['pre', 'hyper', 'yes', 'reduced', 'no lenses'], ['pre', 'hyper', 'yes', 'normal', 'no lenses'], ['presbyopic', 'myope', 'no', 'reduced', 'no lenses'], ['presbyopic', 'myope', 'no', 'normal', 'no lenses'], ['presbyopic', 'myope', 'yes', 'reduced', 'no lenses'], ['presbyopic', 'myope', 'yes', 'normal', 'hard'], ['presbyopic', 'hyper', 'no', 'reduced', 'no lenses'], ['presbyopic', 'hyper', 'no', 'normal', 'soft'], ['presbyopic', 'hyper', 'yes', 'reduced', 'no lenses'], ['presbyopic', 'hyper', 'yes', 'normal', 'no lenses']]
>>> import trees
>>> lensesLabels = ['age','prescript','astigmatic','tearRate']
>>> lensesTree = trees.createTree(lenses,lensesLabels)
>>> lensesTree
{'tearRate': {'reduced': 'no lenses', 'normal': {'astigmatic': {'yes': {'prescript': {'hyper': {'age': {'pre': 'no lenses', 'presbyopic': 'no lenses', 'young': 'hard'}}, 'myope': 'hard'}}, 'no': {'age': {'pre': 'soft', 'presbyopic': {'prescript': {'hyper': 'soft', 'myope': 'no lenses'}}, 'young': 'soft'}}}}}}
>>> import treePlotter
>>> treePlotter.createPlot(lensesTree)
<span style="font-size:18px">
</span>
<span style="font-size:18px"><img src="https://img-blog.csdn.net/20151127122636905?watermark/2/text/aHR0cDovL2Jsb2cuY3Nkbi5uZXQv/font/5a6L5L2T/fontsize/400/fill/I0JBQkFCMA==/dissolve/70/gravity/Center" alt="" /></span>
<span style="font-size:18px">。</span>
Of course this decision tree still has quite a few shortcomings. One is overfitting: the tree ends up with too many branches, so some nodes and branches should be pruned away. Another is that we are handling nominal data, i.e. values that are named categories, rather than numerical data, whose values can be arbitrary; numerical data can, however, be quantized into nominal categories first, as sketched below.
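A minimal sketch of that last idea (the function name, cut points and labels here are all made up for illustration, not from the book):

def discretize(value, cuts=(2.0, 5.0), labels=('low', 'medium', 'high')):
    # map a numeric value onto a nominal label using fixed cut points
    if value < cuts[0]:
        return labels[0]
    elif value < cuts[1]:
        return labels[1]
    else:
        return labels[2]

>>> discretize(3.5)
'medium'

The resulting labels can then be fed to createTree just like any other nominal feature.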
<span style="font-size:18px;">
</span>
<span style="font-size:18px;">
</span>
<span style="font-size:18px;">希望大家多多指导。

</span>
<span style="font-size:18px;"></span>
