# Decision tree (决策树)
# -*- coding: utf-8 -*-
"""ID3 decision tree: entropy-based tree construction, classification of new
samples, and matplotlib rendering of the resulting tree (nested-dict form,
as in *Machine Learning in Action*, ch. 3)."""
from math import log
import operator

# Box / arrow styles used when drawing the tree.
decisionNode = dict(boxstyle="sawtooth", fc="0.8")
leafNode = dict(boxstyle="round4", fc="0.8")
arrow_args = dict(arrowstyle="<-")


def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    """Draw one annotated node with an arrow from parentPt to centerPt."""
    createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',
                            xytext=centerPt, textcoords='axes fraction',
                            va="center", ha="center", bbox=nodeType,
                            arrowprops=arrow_args)


def plotMidText(cntrPt, parentPt, txtString):
    """Write txtString halfway between a parent node and a child node."""
    xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]
    yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]
    createPlot.ax1.text(xMid, yMid, txtString,
                        va="center", ha="center", rotation=30)


def plotTree(myTree, parentPt, nodeTxt):
    """Recursively plot a tree dict; its single top key is the split feature.

    Uses function attributes (totalW/totalD/xOff/yOff) set by createPlot to
    lay nodes out on the [0, 1] axes-fraction grid.
    """
    numLeafs = getNumLeafs(myTree)  # width of this subtree, in leaves
    # list(...) needed on Python 3, where dict.keys() is a non-indexable view
    firstStr = list(myTree.keys())[0]
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW,
              plotTree.yOff)
    plotMidText(cntrPt, parentPt, nodeTxt)  # label the branch value
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
    plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD  # descend one level
    for key in secondDict:
        if isinstance(secondDict[key], dict):  # internal node: recurse
            plotTree(secondDict[key], cntrPt, str(key))
        else:  # leaf node: draw it one slot to the right
            plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff),
                     cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD  # restore level


def createPlot(inTree):
    """Render inTree in a fresh matplotlib figure (no ticks, no frame)."""
    # Imported lazily so the tree-building/classification functions can be
    # used in environments without matplotlib installed.
    import matplotlib.pyplot as plt
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    axprops = dict(xticks=[], yticks=[])
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
    plotTree.totalW = float(getNumLeafs(inTree))
    plotTree.totalD = float(getTreeDepth(inTree))
    plotTree.xOff = -0.5 / plotTree.totalW
    plotTree.yOff = 1.0
    plotTree(inTree, (0.5, 1.0), '')
    plt.show()


def getNumLeafs(myTree):
    """Return the number of leaf nodes in a tree dict."""
    numLeafs = 0
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]
    for key in secondDict:
        if isinstance(secondDict[key], dict):  # sub-tree: count its leaves
            numLeafs += getNumLeafs(secondDict[key])
        else:
            numLeafs += 1
    return numLeafs


def getTreeDepth(myTree):
    """Return the number of split levels (depth) of a tree dict."""
    maxDepth = 0
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]
    for key in secondDict:
        if isinstance(secondDict[key], dict):
            thisDepth = 1 + getTreeDepth(secondDict[key])
        else:
            thisDepth = 1
        if thisDepth > maxDepth:
            maxDepth = thisDepth
    return maxDepth


def majorityCnt(classList):
    """Return the most frequent class label in classList (majority vote)."""
    classCount = {}
    for vote in classList:
        classCount[vote] = classCount.get(vote, 0) + 1
    # .items() replaces Python-2-only dict.iteritems()
    sortedClassCount = sorted(classCount.items(),
                              key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]


def calcShannonEnt(dataSet):
    """Shannon entropy of the class labels (last column) of dataSet."""
    numEntries = len(dataSet)
    labelCounts = {}
    for featVec in dataSet:
        currentLabel = featVec[-1]
        labelCounts[currentLabel] = labelCounts.get(currentLabel, 0) + 1
    shannonEnt = 0.0
    for key in labelCounts:
        prob = float(labelCounts[key]) / numEntries  # class probability
        shannonEnt -= prob * log(prob, 2)  # accumulate -p*log2(p)
    return shannonEnt


def splitDataSet(dataSet, axis, value):
    """Return the rows whose column `axis` equals value, with that column cut.

    dataSet: list of feature rows; axis: feature index to split on;
    value: required feature value for a row to be kept.
    """
    retDataSet = []
    for featVec in dataSet:
        if featVec[axis] == value:
            reducedFeatVec = featVec[:axis]
            reducedFeatVec.extend(featVec[axis + 1:])  # drop the split column
            retDataSet.append(reducedFeatVec)
    return retDataSet


def chooseBestFeatureToSPLIT(dataSet):
    """Return the index of the feature with the highest information gain
    (ID3 criterion), or -1 if no split improves on the base entropy."""
    numFeatures = len(dataSet[0]) - 1  # last column is the class label
    baseEntropy = calcShannonEnt(dataSet)
    bestInfoGain = 0.0
    bestFeature = -1
    for i in range(numFeatures):  # evaluate every candidate feature
        uniqueVals = set(example[i] for example in dataSet)
        newEntropy = 0.0
        for value in uniqueVals:  # weighted entropy of this split
            subDataSet = splitDataSet(dataSet, i, value)
            prob = len(subDataSet) / float(len(dataSet))
            newEntropy += prob * calcShannonEnt(subDataSet)
        infoGain = baseEntropy - newEntropy
        if infoGain > bestInfoGain:
            bestInfoGain = infoGain
            bestFeature = i
    return bestFeature


def createTree(dataSet, labels):
    """Build an ID3 decision tree (nested dicts) from dataSet.

    dataSet: rows of feature values with the class label last;
    labels: feature names aligned with the feature columns.
    NOTE: labels is mutated (the chosen feature is deleted), matching the
    original behavior — pass a copy if the caller needs the list intact.
    """
    classList = [example[-1] for example in dataSet]
    if classList.count(classList[0]) == len(classList):
        return classList[0]  # every row has the same class: leaf
    if len(dataSet[0]) == 1:
        return majorityCnt(classList)  # features exhausted: majority vote
    bestFeat = chooseBestFeatureToSPLIT(dataSet)
    bestFeatLabel = labels[bestFeat]
    myTree = {bestFeatLabel: {}}
    del labels[bestFeat]
    uniqueVals = set(example[bestFeat] for example in dataSet)
    for value in uniqueVals:
        subLabels = labels[:]  # copy so sibling branches share one label list
        myTree[bestFeatLabel][value] = createTree(
            splitDataSet(dataSet, bestFeat, value), subLabels)
    return myTree


def classify(inputTree, featLabels, testVec):
    """Classify testVec by walking inputTree.

    featLabels maps feature names to testVec indices. Returns the leaf label,
    or None when testVec's value matches no branch (the original raised
    UnboundLocalError in that case).
    """
    firstStr = list(inputTree.keys())[0]
    secondDict = inputTree[firstStr]
    featIndex = featLabels.index(firstStr)
    classLabel = None
    for key in secondDict:
        if testVec[featIndex] == key:
            # BUG FIX: original tested type(...)._name_ (single underscores),
            # which raised AttributeError on every internal-node match.
            if isinstance(secondDict[key], dict):
                classLabel = classify(secondDict[key], featLabels, testVec)
            else:
                classLabel = secondDict[key]
    return classLabel


if __name__ == "__main__":
    # Demo: build and plot a tree from the contact-lens dataset.
    # Guarded so importing this module no longer reads files or opens windows.
    with open(r'F:\python\lenses.txt') as fr:
        lenses = [inst.strip().split('\t') for inst in fr]
    lensesLabels = ['age', 'prescript', 'astigmatic', 'tearRate']
    lenseTree = createTree(lenses, lensesLabels)
    print(lenseTree)
    createPlot(lenseTree)