from math import log
import operator
def createDataSet():
    """Return a small toy data set and the names of its feature columns.

    Every row has the form ``[no surfacing, flippers, class label]``: the
    first two columns are discrete 0/1 feature values and the last column
    is the class label.

    Returns:
        tuple: ``(dataSet, labels)`` where ``dataSet`` is a list of rows and
        ``labels`` names the feature columns in column order.
    """
    dataSet = [
        [1, 1, 'maybe'],
        [1, 1, 'yes'],
        [1, 0, 'no'],
        [0, 1, 'no'],
        [0, 1, 'no'],
    ]
    labels = ['no surfacing', 'flippers']
    return dataSet, labels
def calcShannonEnt(dataSet):
    """Return the base-2 Shannon entropy of the class labels in *dataSet*.

    The entropy is ``H = -(p1*log2(p1) + p2*log2(p2) + ... + pn*log2(pn))``
    where ``pi`` is the fraction of rows whose class label (the last element
    of the row) is the i-th distinct label. The more mixed the labels are,
    the higher the entropy.

    Args:
        dataSet: sequence of rows; the last element of each row is its label.

    Returns:
        float: the entropy; ``0.0`` for an empty data set.
    """
    numEntries = len(dataSet)
    # Count how often each distinct class label occurs.
    labelCounts = {}
    for featVec in dataSet:
        currentLabel = featVec[-1]
        labelCounts[currentLabel] = labelCounts.get(currentLabel, 0) + 1
    shannonEnt = 0.0
    for count in labelCounts.values():
        prob = float(count) / numEntries
        shannonEnt -= prob * log(prob, 2)  # log base 2
    return shannonEnt
def splitDataSet(dataSet, axis, value):
    """Return the rows whose column *axis* equals *value*, minus that column.

    Args:
        dataSet: list of rows (each row a list).
        axis: index of the feature column to test and remove.
        value: feature value a row must have in column *axis* to be kept.

    Returns:
        list: new row lists (the input rows are not modified) holding the
        remaining columns of every matching row, in the original order.
    """
    # Slicing around ``axis`` chops out the column used for splitting and
    # always builds a fresh list, so the caller's rows stay untouched.
    return [
        featVec[:axis] + featVec[axis + 1:]
        for featVec in dataSet
        if featVec[axis] == value
    ]
def chooseBestFeatureToSplit(dataSet):
    """Return the index of the feature whose split gives the largest gain.

    For every feature column the data set is partitioned by the distinct
    values of that column and the weighted entropy of the partitions is
    computed. The information gain is the entropy before the split minus
    the weighted entropy after it; the larger the drop, the purer the
    resulting subsets and the better the split.

    Args:
        dataSet: non-empty list of rows; the last column holds class labels.

    Returns:
        int: index of the best feature column, or ``-1`` when no feature
        yields a strictly positive information gain.
    """
    numFeatures = len(dataSet[0]) - 1  # the last column is the class label
    baseEntropy = calcShannonEnt(dataSet)
    bestInfoGain = 0.0
    bestFeature = -1
    for i in range(numFeatures):
        # Distinct values taken by feature i across all examples.
        uniqueVals = {example[i] for example in dataSet}
        # Weighted entropy of the partition induced by feature i.
        newEntropy = 0.0
        for value in uniqueVals:
            subDataSet = splitDataSet(dataSet, i, value)
            prob = len(subDataSet) / float(len(dataSet))
            newEntropy += prob * calcShannonEnt(subDataSet)
        infoGain = baseEntropy - newEntropy  # reduction in entropy
        if infoGain > bestInfoGain:
            bestInfoGain = infoGain
            bestFeature = i
    return bestFeature


def majorityCnt(classList):
    """Return the class label that occurs most often in *classList*.

    Ties are resolved in favour of the label that appears first, because
    the sort is stable.

    Raises:
        IndexError: if *classList* is empty.
    """
    classCount = {}
    for vote in classList:
        classCount[vote] = classCount.get(vote, 0) + 1
    sortedClassCount = sorted(
        classCount.items(), key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]


def createTree(dataSet, labels):
    """Recursively build a decision tree as nested dicts.

    Args:
        dataSet: non-empty list of rows; the last column holds class labels.
        labels: feature names, one per feature column of *dataSet*. The
            list is not modified.

    Returns:
        Either a class label (leaf) or a dict of the form
        ``{featureName: {featureValue: subtree_or_label, ...}}``.
    """
    classList = [example[-1] for example in dataSet]
    # Stop splitting when every example has the same class.
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    # Stop splitting when no features are left: fall back to majority vote.
    if len(dataSet[0]) == 1:
        return majorityCnt(classList)
    bestFeat = chooseBestFeatureToSplit(dataSet)
    # -1 means no feature reduces the entropy. Using it as a column index
    # would silently split on the class column, so emit a leaf instead.
    if bestFeat < 0:
        return majorityCnt(classList)
    bestFeatLabel = labels[bestFeat]
    # Build the reduced label list without touching the caller's list.
    subLabels = labels[:bestFeat] + labels[bestFeat + 1:]
    myTree = {bestFeatLabel: {}}
    uniqueVals = {example[bestFeat] for example in dataSet}
    for value in uniqueVals:
        myTree[bestFeatLabel][value] = createTree(
            splitDataSet(dataSet, bestFeat, value), subLabels)
    return myTree
def classify(inputTree, featLabels, testVec):
    """Classify *testVec* by walking the decision tree *inputTree*.

    Args:
        inputTree: nested dict ``{featureName: {featureValue: subtree}}``
            where a subtree is either another such dict or a class label.
        featLabels: feature names; the position of a name in this list is
            the index of that feature's value in *testVec*.
        testVec: feature values of the sample to classify.

    Returns:
        The class label stored at the leaf reached by *testVec*.

    Raises:
        ValueError: if the tree tests a feature missing from *featLabels*.
        KeyError: if *testVec* has a feature value the tree has no branch for.
    """
    # The single key of the root dict is the name of the feature tested here.
    firstStr = list(inputTree)[0]
    # Its value maps each feature value to a subtree or a class label.
    secondDict = inputTree[firstStr]
    # Locate that feature's column in the test vector.
    featIndex = featLabels.index(firstStr)
    valueOfFeat = secondDict[testVec[featIndex]]
    if isinstance(valueOfFeat, dict):
        return classify(valueOfFeat, featLabels, testVec)
    return valueOfFeat
import b
# Demo: classify two samples using a pre-built tree taken from module ``b``.
myDat, labels = createDataSet()
# ret = chooseBestFeatureToSplit(myDat)
# myTree = createTree(myDat, labels)
myTree = b.retrieveTree(0)
# print(myTree)
ret = classify(myTree, labels, [1, 0])
print(ret)
ret = classify(myTree, labels, [1, 1])
print(ret)
import matplotlib.pyplot as plt
# Text-box and arrow styles used when annotating tree nodes.
decisionNode = dict(boxstyle="sawtooth", fc="0.8")
leafNode = dict(boxstyle="round4", fc="0.8")
arrow_args = dict(arrowstyle="<-")


def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    """Annotate the axes stored on ``createPlot.ax1`` with one tree node.

    Args:
        nodeTxt: text shown inside the node's box.
        centerPt: ``(x, y)`` position of the text box, in axes fraction.
        parentPt: ``(x, y)`` position of the parent node, in axes fraction;
            it is the other end of the arrow.
        nodeType: bbox style dict, e.g. ``decisionNode`` or ``leafNode``.

    ``createPlot.ax1`` must have been set before this is called.
    """
    createPlot.ax1.annotate(
        nodeTxt,
        xy=parentPt, xycoords='axes fraction',
        xytext=centerPt, textcoords='axes fraction',
        va="center", ha="center", bbox=nodeType, arrowprops=arrow_args)


def getNumLeafs(myTree):
    """Return the number of leaves in the nested-dict tree *myTree*.

    The tree has the form ``{featureName: {featureValue: child, ...}}``;
    a child that is a dict is a subtree, anything else is a leaf.
    """
    numLeafs = 0
    # dict views are not indexable in Python 3, so materialise the keys.
    firstStr = list(myTree)[0]
    for child in myTree[firstStr].values():
        if isinstance(child, dict):
            numLeafs += getNumLeafs(child)
        else:
            numLeafs += 1
    return numLeafs
def getTreeDepth(myTree):
    """Return the depth of the nested-dict tree *myTree*.

    The depth is the largest number of decision nodes on any path from the
    root to a leaf. A child that is a dict is a subtree, anything else is
    a leaf.
    """
    maxDepth = 0
    firstStr = list(myTree)[0]
    for child in myTree[firstStr].values():
        if isinstance(child, dict):
            thisDepth = 1 + getTreeDepth(child)
        else:
            thisDepth = 1
        if thisDepth > maxDepth:
            maxDepth = thisDepth
    return maxDepth
def retrieveTree(i):
    """Return the *i*-th hard-coded demo tree (a freshly built nested dict).

    Raises:
        IndexError: if *i* is not a valid index into the demo list.
    """
    listOfTrees = [
        {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},
        {'no surfacing': {0: 'no', 1: {'flippers': {
            0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}},
    ]
    return listOfTrees[i]


def plotMidText(cntrPt, parentPt, txtString):
    """Write *txtString* halfway between *cntrPt* and *parentPt*.

    The text is drawn on ``createPlot.ax1``, rotated by 30 degrees.
    """
    xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]
    yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]
    createPlot.ax1.text(
        xMid, yMid, txtString, va="center", ha="center", rotation=30)


def plotTree(myTree, parentPt, nodeTxt):
    """Recursively draw *myTree* below the node located at *parentPt*.

    Layout state lives in function attributes that must be initialised
    before the first call: ``plotTree.totalW`` and ``plotTree.totalD`` (the
    divisors for horizontal and vertical steps) and ``plotTree.xOff`` /
    ``plotTree.yOff`` (the running cursor, which this function updates).

    Args:
        myTree: nested dict ``{featureName: {featureValue: child, ...}}``.
        parentPt: ``(x, y)`` of the parent node.
        nodeTxt: text written on the edge from the parent to this subtree.
    """
    numLeafs = getNumLeafs(myTree)  # determines the x width of this subtree
    # The first key names the feature this node splits on; it is the label.
    firstStr = list(myTree)[0]
    cntrPt = (
        plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW,
        plotTree.yOff,
    )
    plotMidText(cntrPt, parentPt, nodeTxt)
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
    # Step one level down before drawing the children.
    plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD
    for key in secondDict:
        if isinstance(secondDict[key], dict):
            # A dict child is a subtree: recurse.
            plotTree(secondDict[key], cntrPt, str(key))
        else:
            # Anything else is a leaf: advance the x cursor and draw it.
            plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff),
                     cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    # Step back up so siblings of this subtree are drawn at the right level.
    plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD
def createPlot(inTree):
    """Draw the decision tree *inTree* in matplotlib figure 1 and show it.

    Args:
        inTree: nested-dict tree in the format accepted by ``getNumLeafs``,
            ``getTreeDepth`` and ``plotTree``.

    The axes are stored on ``createPlot.ax1`` and the layout state on
    ``plotTree.totalW`` / ``totalD`` / ``xOff`` / ``yOff`` so the drawing
    helpers can reach them.

    Example: ``createPlot(retrieveTree(0))``.
    """
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    axprops = dict(xticks=[], yticks=[])  # hide the axis ticks
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
    plotTree.totalW = float(getNumLeafs(inTree))
    plotTree.totalD = float(getTreeDepth(inTree))
    # Start half a leaf-width left of the axes so the first leaf lands
    # half a leaf-width inside them.
    plotTree.xOff = -0.5 / plotTree.totalW
    plotTree.yOff = 1.0
    plotTree(inTree, (0.5, 1.0), '')
    plt.show()