决策树 算法原理及代码

   决策树可以使用不熟悉的数据集合,并从中提取出一系列的规则,这是机器根据数据集创建规则的过程,就是机器学习的过程。用一个小案例分析:

 

通过No surfacing  和 flippers判断该生物是否是鱼,No surfacing 是离开水面是否可以生存,flippers判断是否有脚蹼

引入信息增益和信息熵的概念:

信息熵:计算熵,我们需要计算所有类别所有可能值包含的信息期望值。

                                        p(x)是类别出现的概率

条件熵(表示在已知随机变量X的条件下随机变量Y的不确定性。):

                                       

信息增益(划分数据集前后的信息发生的变化,通俗的说,就是信息熵减去条件熵):

                                         

代码实现:

      加载数据:

def createDataSet():
    dataSet = [
        [1,1,'yes'],
        [1,1,'yes'],
        [1,0,'no'],
        [0,1,'no'],
        [0,1,'no']
        
    ]
    labels = ['no surfacing','flippers']
    return dataSet,labels

计算原始熵:

def calcShannonEnt(dataSet):
    numEntries = len( dataSet)
    labelCounts = { }
    for featVec in dataSet:
        currentLabel = featVec[-1]
        if  currentLabel not in labelCounts.keys():
            labelCounts[currentLabel] = 0
        labelCounts [currentLabel]+=1
    shannonEnt = 0.0 
    for key in labelCounts :
        prob = float (labelCounts[key])/numEntries
        shannonEnt -=prob * log(prob,2)
    return shannonEnt 
划分数据集
def splitDataSet(dataSet,axis,value):  # 待划分的数据集  ,划分数据集的特征,需要返回的特征的值
    retDataSet=[]
    for featVec in dataSet:
        if featVec[axis] == value :
            reduceFeatVec=featVec[:axis]  #取不到axis这一行
            reduceFeatVec.extend(featVec[axis+1:])
            retDataSet.append(reduceFeatVec)
    return retDataSet

测试数据及结果:


(myDat,0,1)  myDat是数据集,0是第一次划分数据集,1是第一列为1的数据

计算出条件熵,然后求出信息增益,并找到最大的信息增益,最大的信息增益就是找到最好的划分数据集的特征

def chooseBestFeatureToSplit(dataSet):
    numFeatures=len(dataSet[0])-1
    #计算出原始的香农熵
    baseEntropy = calcShannonEnt(dataSet)
    bestInfoGain = 0.0; bestFeature =-1
    for i in range (numFeatures):
        #创建唯一的分类标签列表
        featList = [example[i] for example in dataSet]
        uniqueVals = set (featList)  #去重复
        #条件熵的初始化
        newEntropy = 0.0
        for value in uniqueVals :
            #划分   获得数据集
            subDataSet = splitDataSet(dataSet,i ,value)
            prob=len(subDataSet)/float(len(dataSet)) # 概率
            #条件熵的计算
            newEntropy += prob * calcShannonEnt (subDataSet)
        # 信息增益
        infoGain = baseEntropy -newEntropy
        if (infoGain >bestInfoGain):
            bestInfoGain = infoGain #找到最大的信息增益
            bestFeature =i  #找出最好的划分数据集的特征
        return bestFeature

测试数据:

dataSet,labels = createDataSet()
print(dataSet)
print(chooseBestFeatureToSplit(dataSet))

输入结果:

投票机制:


def majorityCnt(classList):
    classCount={}
    for vote in classList:
        if vote not in classCount.keys() :
            classCount[vote]=0
    sortedClassCount = sorted (classCount.iteritems(),key=operator.itemgetter(1),reverse=True)
    return sortedClassCount[0][0] 

创建树:

def createTree(dataSet,labels):
    classList = [example[-1] for example in dataSet]
    if classList.count(classList[0] )== len (classList) :
        return classList[0]
    if len(dataSet[0]) == 1:
        return majorityCnt(classList)
    bestFeat = chooseBestFeatureToSplit(dataSet)
    bestFeatLabel = labels[bestFeat]
    myTree={bestFeatLabel:{}}
    del(labels[bestFeat])
    featValues = [example[bestFeat] for example in dataSet]
    uniqueVals = set( featValues)
    for value in uniqueVals:
        subLabels =labels[:]
        myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet,bestFeat,value),subLabels)
    return myTree

结果:


该方法是用信息增益的方法来构建树,在查阅其他的博客得知:

    ID3算法主要是通过信息增益的大小来判定,最大信息增益的特征就是当前节点,这个算法存在许多的不足,第一,它解决不了过拟合问题,和缺失值的处理,第二,信息增益偏向取值较多的特征,第三,不能处理连续特征问题。

因此,引入C4.5算法,是利用信息增益率来代替信息增益。为了减少过度匹配问题,我们通过剪枝来处理冗余的数据,生成决策树时决定是否要剪枝叫预剪枝,生成树之后进行交叉验证的叫后剪枝。

还有一个是引入基尼指数来进行计算叫CART树,以后再做介绍。

绘制树形图:

decisionNode = dict(boxstyle = "sawtooth", fc="0.8")
leafNode = dict(boxstyle = "round4" ,fc="0.8")
arrow_args = dict(arrowstyle="<-")
def plotNode(nodeTxt,centerPt,parentPt,nodeType):
    createPlot.ax1.annotate(nodeTxt,xy=parentPt,xycoords='axes fraction' ,\
                           xytext=centerPt ,textcoords='axes fraction',va="center" ,\
                            ha ="center" ,bbox=nodeType,arrowprops = arrow_args)

def getNumLeafs(myTree):
    numLeafs= 0
    firstStr =list(myTree.keys())[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__=='dict':
            numLeafs +=getNumLeafs(secondDict[key])
        else :  numLeafs+=1
    return numLeafs
def getTreeDepth(myTree) :
    maxDepth=0
    firstStr =list(myTree.keys())[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__=='dict':
            thisDepth = 1+ getTreeDepth(secondDict[key])
        else : thisDepth = 1
        if thisDepth > maxDepth : maxDepth=thisDepth
    return maxDepth
def plotMidText(cntrPt , parentPt ,txtString) :
    xMid = (parentPt[0]-cntrPt[0])/2.0 +cntrPt[0]
    yMid = (parentPt[1]-cntrPt[1])/2.0 +cntrPt[1]
    createPlot.ax1.text(xMid,yMid,txtString)
def plotTree(myTree,parentPt,nodeTxt):
    numLeafs = getNumLeafs(myTree)
    depth = getTreeDepth(myTree)
    firstStr = list(myTree.keys())[0]
    cntrPt = (plotTree.xOff+(1.0+float(numLeafs))/2.0/plotTree.totalW,plotTree.yOff)
    plotMidText(cntrPt,parentPt,nodeTxt)
    plotNode(firstStr,cntrPt,parentPt,decisionNode)
    secondDict = myTree[firstStr]
    plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD
    for key in secondDict.keys():
        if type(secondDict[key]).__name__=='dict':
            plotTree(secondDict[key],cntrPt,str(key))
        else:
            plotTree.xOff = plotTree.xOff +1.0 /plotTree.totalW
            plotNode(secondDict[key],(plotTree.xOff,plotTree.yOff),cntrPt,leafNode)
            plotMidText((plotTree.xOff,plotTree.yOff) ,cntrPt,str(key))
    plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD

def createPlot(inTree) :
    fig = plt.figure(1,facecolor = 'white')
    fig.clf()
    axprops =dict(xticks=[],yticks=[])
    createPlot.ax1 = plt.subplot(111,frameon = False ,**axprops)
    plotTree.totalW =float(getNumLeafs(inTree))
    plotTree.totalD=float(getTreeDepth(inTree))
    plotTree.xOff = -0.5/plotTree.totalW;plotTree.yOff = 1.0
    plotTree(inTree,(0.5,1.0),'')
    plt.show()
createPlot(myTree)





  • 3
    点赞
  • 28
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

樱缘之梦

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值