决策树

 决策树的一个重要任务,就是为了理解数据中蕴含的知识信息,因此决策树可以使用不熟悉的数据集合,并从中提取出一系列规则,这些机器根据数据集创建规则的过程,就是机器学习的过程。

一、确定划分数据集的决定性特征

信息增益:划分数据集前后信息发生的变化

信息:l(xi)=-log2p(xi),p(xi)是选择该分类的概率

熵(信息的期望值,表示序集无需程度的度量):H=-Σp(xi)log2p(xi)

二、划分数据集:

①计算原始数据集的熵

②将需要划分特征的数据标签抽取出来,计算剩下的数据集的熵

③将两个熵做差,算出计算得信息增益

④对每一个数据标签重复上述操作,最终算出信息增益最大得数据标签,这个数据标签就是决定性特征。

⑤对剩下来得数据标签再进行挑选分类,建立决策树。

三、存在问题

决策树可能会产生过多的数据集划分,从而产生过度匹配数据集的问题。我们可以通过裁剪决策树,合并相邻的无法产生大量信息增益的叶节点,消除过度匹配问题

决策树和KNN算法都是谈论具有明确分类的分类算法,朴素贝叶斯分类是一定概率的分类算法。

 

import operator
import matplotlib.pyplot as plt
from math import log
#计算给定数据集的香农熵
def calcShannonEnt(dataset):
    numEntries=len(dataset)
    labelCounts={}
    for featVec in dataset:
        currentLabel=featVec[-1]
        if currentLabel not in labelCounts.keys():
            labelCounts[currentLabel]=0
        labelCounts[currentLabel]+=1
    shannonEnt=0.0
    for key in labelCounts:
        prob=float(labelCounts[key])/numEntries
        shannonEnt-=prob*log(prob,2)
    return shannonEnt
#mydat=[[1,1,'yes'],[1,1,'yes'],[1,0,'no'],[0,1,'no'],[0,1,'no']]
#print(calcShannonEnt(mydat))
#按照给定特征划分数据集
def splitDataSet(dataset,axis,value):
    retDataSet=[]
    for featVec in dataset:
        if featVec[axis] == value:
            reducedFeatVec=featVec[:axis]
            reducedFeatVec.extend(featVec[axis+1:])
            retDataSet.append(reducedFeatVec)
    return retDataSet
def chooseBestFeatureTosplit(dataset):
    numFeatures=len(dataset[0])-1#获取特征数目-1
    baseEntropy=calcShannonEnt(dataset)#计算原始香农熵
    bestInfoGain=0.0;bestFeature=-1
    for i in range(numFeatures):
        featList=[example[i] for example in dataset]#建立dataset中的每一列标签的值
        uniqueVals=set(featList)#唯一化
        newEntropy=0.0
        for value in uniqueVals:
            subDataSet=splitDataSet(dataset,i,value)
            prob=len(subDataSet)/float(len(dataset))
            newEntropy+=prob*calcShannonEnt(subDataSet)#计算抽取了value后的香农熵之和
        infoGain=baseEntropy-newEntropy#计算信息增益
        if(infoGain>bestInfoGain):#找到最好的划分
            bestInfoGain=infoGain
            bestFeature=i
    return bestFeature
#mydat=[[1,1,'yes'],[1,1,'yes'],[1,0,'no'],[0,1,'no'],[0,1,'no']]
#print(chooseBestFeatureTosplit(mydat)) OUTPUT:0表明第0个特征用于划分数据集最好的特征
#返回出现最多次的分类名称
def majorityCnt(classList):
     classCount={}
     for vote in classList:
         if vote not in classCount.keys():classCount[vote]=0
         classCount[vote]+=1
     sortedClassCount=sorted(classCount.iteritems(),key=operator.itemgetter(1),reversed=True)
     return sortedClassCount[0][0]
 #递归构建决策树
def createTree(dataset,labels):
     classList=[example[-1] for example in dataset]#最后一列
     if classList.count(classList[0])==len(classList):#count() 方法用于统计某个元素在列表中出现的次数,如果所有的类别都是相同的
         return classList[0]
     if len(dataset[0])==1:#只剩下一个
         return majorityCnt(classList)
     bestFeat=chooseBestFeatureTosplit(dataset)
     bestFeatLabel=labels[bestFeat]
     myTree={bestFeatLabel:{}}
     del(labels[bestFeat])
     featValues=[example[bestFeat] for example in dataset]
     uniqueVals=set(featValues)
     for value in uniqueVals:
         subLabels=labels[:]
         myTree[bestFeatLabel][value]=createTree(splitDataSet(dataset,bestFeat,value),subLabels)
     return myTree
 #绘制图形
decisionNode=dict(boxstyle="sawtooth",fc="0.8")
leafNode = dict(boxstyle="round4", fc="0.8")
arrow_args=dict(arrowstyle="<-")

def plotNode(nodeText,centerPt,parentPt,nodeType):
    createPlot.ax1.annotate(nodeText,xy=parentPt,xycoords='axes fraction',xytext=centerPt,textcoords='axes fraction',va="center",
                            ha="center",bbox=nodeType,arrowprops=arrow_args)
def getNumLeafs(myTree):
     numLeafs=0
     firstStr=list(myTree.keys())[0]
     secondDict=myTree[firstStr]
     for key in secondDict.keys():
         if type(secondDict[key]).__name__=='dict':
             numLeafs+=getNumLeafs(secondDict[key])
         else:
            numLeafs+=1
     return numLeafs
def getTreeDepth(myTree):
     maxDepth=0
     firstStr=list(myTree.keys())[0]
     secondDict=myTree[firstStr]
     for key in secondDict.keys():
         if type(secondDict[key]).__name__=='dict':
             thisDepth=1+getTreeDepth(secondDict[key])
         else:  thisDepth=1
         if thisDepth>maxDepth:maxDepth=thisDepth
     return maxDepth
def plotMidText(cntrPt,parentPt,txtString):
     xMid=(parentPt[0]-cntrPt[0])/2.0+cntrPt[0]
     yMid=(parentPt[1]-cntrPt[1])/2.0+cntrPt[1]
     createPlot.ax1.text(xMid,yMid,txtString)
def plotTree(myTree,parentPt,nodeTxt):
    numLeafs=getNumLeafs(myTree)
    depth=getTreeDepth(myTree)
    firstStr=list(myTree.keys())[0]
    cntrPt=(plotTree.xOff+(1.0+float(numLeafs))/2.0/plotTree.totalW,plotTree.yOff)
    plotMidText(cntrPt,parentPt,nodeTxt)
    plotNode(firstStr,cntrPt,parentPt,decisionNode)
    secondDict=myTree[firstStr]
    plotTree.yOff=plotTree.yOff-1.0/plotTree.totalD
    for key in secondDict.keys():
        if type(secondDict[key]).__name__=='dict':
            plotTree(secondDict[key],cntrPt,str(key))
        else:
            plotTree.xOff=plotTree.xOff+1.0/plotTree.totalW
            plotNode(secondDict[key],(plotTree.xOff,plotTree.yOff),cntrPt,leafNode)
            plotMidText((plotTree.xOff,plotTree.yOff),cntrPt,str(key))
    plotTree.yOff=plotTree.yOff+1.0/plotTree.totalD
def createPlot(inTree):
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    axprops = dict(xticks=[], yticks=[])
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
    plotTree.totalW = float(getNumLeafs(inTree))
    plotTree.totalD=float(getTreeDepth(inTree))
    plotTree.xOff=-0.5/plotTree.totalW;plotTree.yOff=1.0;
    plotTree(inTree,(0.5,1.0),'')
    plt.show()

dataset=[[1,1,'yes'],[1,1,'yes'],[1,0,'no'],[0,1,'no'],[0,1,'no']]
labels=['no surfacing','flippers']#前两列属性的名称,dataset第三列表示的是目标变量
myTree=createTree(dataset,labels)
print(myTree)
createPlot(myTree)
Input:传入的数据分为n+1列,前n列对应n个属性(labels),最后一列是目标变量
Output:决策树

------------------------

虽然内容还没有跑出来,问题出现在import导入的matplotlib这个包,问题分析是之前下载了annacode,和自带的python冲突了,我把annacode的环境加进来,好像也不行,代码界面报错消失了,但是一运行就会报错,有博主说是因为创建的不是python package下的python文件,测试后发现仍然不行,最终!!!!!成了!!!!

解决anaconda与pycharm冲突详情见https://www.cnblogs.com/code-fun/p/12488711.html

最终!

 

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值