一只菜鸡的决策树入门(三)

      画树完成之后,发现数据形式的树很难直观的看出其中的分类决策,尤其当参与分类的特征较多的情况下。个人工作是做信贷风控决策,碰到的特征变量数不胜数,简直头疼,书上选用matplotlib库的注解功能尝试绘制属性图,这个库之前入门python的时候就了解过,参数实在是太多,比较烦,用来试试。

      以文本注解的形式绘制树的节点:

from matplotlib import pyplot as plt

# 节点和箭头的样式确定一下
decisionNode = dict(boxstyle='sawtooth',fc='0.8')
leafNode = dict(boxstyle='round4',fc='0.8')
arrow_args = dict(arrowstyle='<-')

# 绘制节点函数
def plotNode(nodetext,cntrPt,parentPt,nodeType):
    # ax1为createPlot属性,调用函数annotate()绘制节点
    createPlot.ax1.annotate(nodetext,xy=parentPt,xycoords='axes fraction',
                            xytext=cntrPt,textcoords='axes fraction',
                                va='center',ha='center',bbox=nodeType,
                                    arrowprops=arrow_args)

# 绘制一下节点试试

def createPlot():
    fig = plt.figure(1,facecolor='white')
    fig.clf() # 清空绘图区
    # 全局变量createPlot.ax1定义绘图区域(变量默认均为全局有效)
    createPlot.ax1 = plt.subplot(111,frameon=False) # 去除表框轴
    plotNode('决策节点',(0.5,0.8),(0.1,0.5),decisionNode)
    plotNode('叶节点',(0.8,0.1),(0.3,0.8),leafNode)
    plt.show()


然后递归函数计算叶子结点数量以及树的层数:

# 计算叶子结点数量
def calcLeafs(myTree):
    numofLeafs = 0
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            numofLeafs += calcLeafs(secondDict[key])
        else:
            numofLeafs += 1
    return numofLeafs

# 计算树的层数(经过几次特征划分)
def calcDepth(myTree):
    numofDepth = 0
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            thisDepth = 1 + calcDepth(secondDict[key])
        else:
            thsiDepth = 1
        if thisDepth > numofDepth:
            numofDepth = thisDepth
    return numofDepth
            
            

      为最后绘制工作增加绘制父子节点间文本的部分:

def plotMidtext(cntrPt,parentPt,textstring):
    xmid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]
    ymid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]
    createPlot.ax1.text(xmid,ymid,textstring)

      createPlot才是真正的大哥,我们的主函数,其他都是小弟,其中的还有一个小弟是负责画树的,把前面的小弟集合一下:

def plotTree(myTree,parentPt,nodeTxt):
    numLeafs = calcLeafs(myTree)
    depth = calcDepth(myTree)
    firstStr = list(myTree.keys())[0]
    SecondDict = myTree[firstStr]
    # 当前绘制的节点坐标(x轴按照该层叶子结点数量均分,当前坐标为中点)
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs) / 2.0 / plotTree.totalW,
                                        plotTree.yOff)
    # 绘制第一层,当前分类的节点
    plotMidText(cntrPt,parentPt,nodeTxt)
    plotNode(firstStr,cntrPt,parentPt,decisionNode)
    # y轴按照层数depth均分,每层绘制完则递减
    plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            plotTree(secondDict[key],cntrPt,str(key))
        else:
            plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
            plotNode(secondDict[key],(plotTree.xOff,plotTree.yOff),cntrPt,leafNode)
            plotMidText((plotTree.xOff,plotTree.yOff),cntrPt,str(key))
    # 校正?
    plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD
    
    
    
    

      有一个看了很久看不懂的,当前节点坐标cntrPt,书中根据该层叶子节点数量均分X轴,当前结点位置位于中间float(numLeafs)/2.0/plotTree.totalW,在初始化plotTree.xOff值的时候取-0.5/plotTree.totalW(叶子结点位于x中间),即整体向左挪动了0.5个x,所以需要将以上向左偏移的0.5个x补回来,叶子结点x坐标则按照每次一个x的距离增加。图是按照比例绘制的。

      接着用createPlot绘制展示:

 def createPlot(inTree):
    fig = plt.figure(1,facecolor='white')
    fig.clf()
    axprops = dict(xticks[],yticks[]) # 刻度轴设置为空
    createPlot.ax1 = plt.subplot(111,frameon=False,**axprops)
    plotTree.totalW = float(calcLeafs(inTree))
    plotTree.totalD = float(calcDepth(inTree))
    plotTree.xOff = -0.5 / plotTree.totalW
    plotTree.yOff = 1.0 # 初始值为1.0,自上而下绘制且y递减
    plotTree(inTree,(0.5,1.0),'')
    plt.show()
    

 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值