# Chapter 3: Decision Trees
from math import log
import operator
import matplotlib.pyplot as plt
plt.rcParams['font.sans-serif']=['SimHei'] # select the SimHei font so Chinese labels can be displayed
plt.rcParams['axes.unicode_minus']=False # original note: these two settings have to be applied manually
# Shannon entropy of a data set:
#   H(x) = -sum( p(x_i) * log2(p(x_i)) )
def calcShannonEnt(dataset):
    """Return the base-2 Shannon entropy of the class labels in *dataset*.

    The class label of each record is taken from its last element.
    An empty data set yields 0.0.
    """
    total = len(dataset)

    # Tally how many records carry each class label.
    tally = {}
    for record in dataset:
        label = record[-1]
        tally[label] = tally.get(label, 0) + 1

    entropy = 0.0
    for count in tally.values():
        # p(x_i): relative frequency of this label
        p = float(count) / total
        # accumulate -p * log2(p)
        entropy -= p * log(p, 2)
    return entropy
def createDataSet():
    """Return the small demo data set and its feature names.

    Returns:
        (dataSet, labels): ``dataSet`` is a list of records of the form
        ``[feature0, feature1, class_label]``; ``labels`` names the two
        feature columns.  Fresh lists are built on every call.
    """
    rows = (
        (1, 1, 'yes'),
        (1, 1, 'yes'),
        (1, 0, 'no'),
        (0, 1, 'no'),
        (0, 1, 'no'),
    )
    feature_names = ['no surfacing', 'flippers']
    return [list(row) for row in rows], feature_names
# Split a data set on one feature.
def splitDataSet(dataSet, axis, value):
    """Return the records whose feature ``axis`` equals ``value``.

    Args:
        dataSet: list of records (lists) to split.
        axis: index of the feature column to test.
        value: feature value a record must have to be kept.

    Returns:
        A new list of new records, each with column ``axis`` removed.
        The input records are not modified.
    """
    matching = (row for row in dataSet if row[axis] == value)
    # Drop the tested column from every kept record.
    return [row[:axis] + row[axis + 1:] for row in matching]
# Choose the best feature to split the data set on.
def chooseBestFeatureToSplit(dataset):
    """Return the index of the feature with the largest information gain.

    Args:
        dataset: list of records; the last element of each record is the
            class label, every other element is a feature value.

    Returns:
        Index of the feature column whose split gives the highest
        information gain, or -1 when the data set is empty or no feature
        yields a strictly positive gain.
    """
    if len(dataset) == 0:
        return -1

    numFeatures = len(dataset[0]) - 1  # last column is the class label
    total = float(len(dataset))
    # Entropy of the unsplit data set; gains are measured against it.
    baseEntropy = calcShannonEnt(dataset)
    bestInfoGain = 0.0
    bestFeature = -1
    for i in range(numFeatures):
        # Distinct values taken by feature i.
        uniqueVals = {example[i] for example in dataset}
        # Weighted entropy of the partitions produced by splitting on feature i.
        newEntropy = 0.0
        for value in uniqueVals:
            subDataSet = splitDataSet(dataset, i, value)
            prob = len(subDataSet) / total
            newEntropy += prob * calcShannonEnt(subDataSet)
        infoGain = baseEntropy - newEntropy
        # Strict comparison: the first feature reaching the maximum gain wins.
        if infoGain > bestInfoGain:
            bestInfoGain = infoGain
            bestFeature = i
    return bestFeature
def majorityCnt(classList):
    """Return the class label that occurs most often in *classList*.

    When several labels share the highest count, the one that appears
    first in *classList* is returned.
    """
    # Frequency of each class label, in first-seen order.
    tally = {}
    for label in classList:
        tally[label] = tally.get(label, 0) + 1

    # Stable descending sort by count keeps first-seen order among ties.
    ranked = list(tally.items())
    ranked.sort(key=lambda pair: pair[1], reverse=True)
    top_label, _ = ranked[0]
    return top_label
# Build the decision tree.
def createTree(dataSet, labels):
    """Recursively build a decision tree as nested dicts.

    Args:
        dataSet: non-empty list of records; the last element of each
            record is the class label.
        labels: names of the feature columns, in column order.  This list
            is NOT modified (a copy is trimmed for the recursive calls), so
            the caller can keep using it afterwards, e.g. with classify().

    Returns:
        Either a class label (leaf) or a dict of the form
        ``{feature_name: {feature_value: subtree_or_label, ...}}``.

    Raises:
        IndexError: if ``dataSet`` is empty.
    """
    classList = [example[-1] for example in dataSet]  # all class labels
    # Stop 1: every record has the same class -> that class is the leaf.
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    # Stop 2: no features left (only the label column) -> majority vote.
    if len(dataSet[0]) == 1:
        return majorityCnt(classList)

    bestFeat = chooseBestFeatureToSplit(dataSet)
    # Stop 3: no feature gives any information gain (index -1).  Splitting on
    # index -1 would split on the class-label column itself, so fall back to
    # the majority class instead.
    if bestFeat < 0:
        return majorityCnt(classList)

    bestFeatLabel = labels[bestFeat]
    myTree = {bestFeatLabel: {}}
    # Work on a copy so the caller's label list keeps all its entries.
    subLabels = list(labels)
    del subLabels[bestFeat]

    # One branch per distinct value of the chosen feature.
    uniqueVals = {example[bestFeat] for example in dataSet}
    for value in uniqueVals:
        myTree[bestFeatLabel][value] = createTree(
            splitDataSet(dataSet, bestFeat, value), subLabels)
    return myTree
# Drawing tree nodes with text annotations.
# Box and arrow styles: decisionNode / leafNode are bbox property dicts,
# arrow_args is an arrow property dict.
decisionNode=dict(boxstyle="sawtooth",fc="0.8")
leafNode=dict(boxstyle="round4",fc="0.8")
arrow_args=dict(arrowstyle="<-")
# Earlier draft of the annotation helpers, left commented out; plotNode and
# createPlot are defined further down in this file.
# def plotNode(nodeTxt,centerPt,parentPt,nodeType):
# createPlot.ax1.annotate(nodeTxt,xy=parentPt,xycoords='axes fraction',xytext=centerPt,textcoords='axes fraction',
# va="center",ha="center",bbox=nodeType,arrowprops=arrow_args)
# def createPlot():
# fig=plt.figure(1,facecolor='white')
# fig.clf()
# createPlot.ax1=plt.subplot(111,frameon=False)
# plotNode(U'决策节点',(0.5,0.1),(0.1,0.5),decisionNode)
# plotNode(U'叶节点',(0.8,0.1),(0.3,0.8),leafNode)
# plt.show()
# Count the leaf nodes of a tree.
def getNumLeafs(myTree):
    """Return the number of leaf nodes in *myTree*.

    Args:
        myTree: tree of the form ``{feature_name: {value: child, ...}}``
            where each child is either a nested tree (a dict) or a leaf.

    Returns:
        Total count of non-dict children reachable from the root.
    """
    firstStr = list(myTree.keys())[0]  # the root's single feature name
    numLeafs = 0
    for child in myTree[firstStr].values():
        # A dict child (including dict subclasses) is a subtree; anything
        # else is a leaf.
        if isinstance(child, dict):
            numLeafs += getNumLeafs(child)
        else:
            numLeafs += 1
    return numLeafs
# Compute the depth of a tree.
def getTreeDepth(myTree):
    """Return the number of decision levels on the longest root-to-leaf path.

    Args:
        myTree: tree of the form ``{feature_name: {value: child, ...}}``
            where each child is either a nested tree (a dict) or a leaf.

    Returns:
        1 for a tree whose children are all leaves, plus 1 for every
        further nested level; 0 if the root has no children.
    """
    firstStr = list(myTree.keys())[0]  # the root's single feature name
    maxDepth = 0
    for child in myTree[firstStr].values():
        # A dict child is itself a decision node, so recurse into it.
        if isinstance(child, dict):
            thisDepth = 1 + getTreeDepth(child)
        else:
            thisDepth = 1
        if thisDepth > maxDepth:
            maxDepth = thisDepth
    return maxDepth
# Return one of the pre-built sample trees.
def retrieveTree(i):
    """Return the *i*-th hard-coded sample tree (for testing the helpers).

    Index 0 is a two-level tree; index 1 has an extra 'head' level.
    """
    two_level = {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
    head_branch = {'head': {0: 'no', 1: 'yes'}}
    three_level = {'no surfacing': {0: 'no', 1: {'flippers': {0: head_branch, 1: 'no'}}}}
    listOfTrees = [two_level, three_level]
    return listOfTrees[i]
# Helpers for plotTree.
# Write a text label halfway between a child node and its parent.
def plotMidText(cntrPt, parentPt, txtString):
    """Draw *txtString* at the midpoint between *cntrPt* and *parentPt*.

    Both points are (x, y) pairs; the text is drawn on ``createPlot.ax1``.
    """
    child_x, child_y = cntrPt[0], cntrPt[1]
    mid_x = (parentPt[0] - child_x) / 2.0 + child_x
    mid_y = (parentPt[1] - child_y) / 2.0 + child_y
    createPlot.ax1.text(mid_x, mid_y, txtString)
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    """Draw one annotated node with an arrow coming from its parent.

    Args:
        nodeTxt: text shown inside the node box.
        centerPt: (x, y) position of the node, in axes-fraction coordinates.
        parentPt: (x, y) position of the parent, in axes-fraction coordinates.
        nodeType: bbox property dict for the node box (e.g. the module-level
            ``decisionNode`` or ``leafNode``).

    The annotation is drawn on ``createPlot.ax1``.
    """
    # Use the caller-supplied box style; previously it was ignored and every
    # node was drawn with the same hard-coded 'round4' box.
    createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',
                            xytext=centerPt, textcoords='axes fraction',
                            va='center', ha='center', bbox=nodeType,
                            arrowprops=arrow_args)
def plotTree(myTree, parentPt, nodeTxt):
    """Recursively draw *myTree* below *parentPt*.

    Args:
        myTree: tree of the form ``{feature_name: {value: child, ...}}``.
        parentPt: (x, y) position of the parent node.
        nodeTxt: label written on the edge leading to this subtree.

    Layout state lives in function attributes that createPlot() initialises
    before the first call: ``plotTree.totalW`` / ``plotTree.totalD`` (leaf
    count and depth of the whole tree) and ``plotTree.xOff`` /
    ``plotTree.yOff`` (current drawing cursor).
    """
    numLeafs = getNumLeafs(myTree)
    firstStr = list(myTree.keys())[0]
    # Centre this decision node horizontally over the leaves beneath it.
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW,
              plotTree.yOff)
    plotMidText(cntrPt, parentPt, nodeTxt)
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
    # Step one level down before drawing the children.
    plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD
    for key in secondDict:
        child = secondDict[key]
        if isinstance(child, dict):
            # Subtree: recurse, labelling the edge with the feature value.
            plotTree(child, cntrPt, str(key))
        else:
            # Leaf: advance the x cursor by one leaf slot and draw it.
            plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
            plotNode(child, (plotTree.xOff, plotTree.yOff),
                     cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    # Restore the y cursor for the caller's remaining siblings.
    plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD
def createPlot(inTree):
    """Draw *inTree* in figure 1 and show it.

    Clears the figure, creates a frameless axes without ticks (stored as
    ``createPlot.ax1``), initialises the layout attributes used by
    plotTree(), draws the tree starting from (0.5, 1.0) and calls
    ``plt.show()``.
    """
    figure = plt.figure(1, facecolor='white')
    figure.clf()
    # Frameless axes with no tick marks; shared with the plotting helpers.
    createPlot.ax1 = plt.subplot(111, frameon=False, xticks=[], yticks=[])

    # Overall tree size, used to scale node positions.
    plotTree.totalW = float(getNumLeafs(inTree))
    plotTree.totalD = float(getTreeDepth(inTree))
    # Drawing cursor: half a leaf slot left of the axes, at the top.
    plotTree.xOff = -0.5 / plotTree.totalW
    plotTree.yOff = 1.0

    root_anchor = (0.5, 1.0)
    plotTree(inTree, root_anchor, '')
    plt.show()
# Classify a feature vector with a decision tree.
def classify(inputTree, featLabels, testVec):
    """Return the class label the tree assigns to *testVec*.

    Args:
        inputTree: tree of the form ``{feature_name: {value: child, ...}}``
            where each child is a nested tree (dict) or a class label.
        featLabels: feature names in the column order used by ``testVec``.
        testVec: feature values of the sample to classify.

    Returns:
        The class label stored in the leaf reached by following
        ``testVec`` through the tree.

    Raises:
        ValueError: if a tree feature name is missing from ``featLabels``
            or the sample's value for a feature has no branch in the tree.
    """
    firstStr = list(inputTree.keys())[0]
    secondDict = inputTree[firstStr]
    # Column of testVec that holds the feature tested at this node.
    featIndex = featLabels.index(firstStr)
    featValue = testVec[featIndex]
    if featValue not in secondDict:
        # Previously this fell through to an unbound local variable.
        raise ValueError(
            'no branch for value %r of feature %r' % (featValue, firstStr))
    child = secondDict[featValue]
    if isinstance(child, dict):
        return classify(child, featLabels, testVec)
    return child
# Persisting a decision tree.
# Store the tree with the pickle module.
def storeTree(inputTree, filename):
    """Pickle *inputTree* to the file *filename* (pickle protocol 0).

    The file is opened in binary write mode and is closed even if
    pickling fails.
    """
    import pickle
    with open(filename, 'wb') as fw:
        pickle.dump(inputTree, fw, 0)
def grabTree(filename):
    """Load and return a pickled tree from the file *filename*.

    The file is opened in binary mode (required by ``pickle.load`` on
    Python 3) and closed once the tree has been read.
    """
    import pickle
    # NOTE: unpickling can execute arbitrary code; only load files that
    # were written by a trusted source (e.g. storeTree in this module).
    with open(filename, 'rb') as fr:
        return pickle.load(fr)
if __name__ == '__main__':
    # Demo: print a sample tree, pickle it to disk and load it back.
    myDat, labels = createDataSet()
    # Earlier experiments, kept for reference:
    # myDat[0][-1]='maybe'
    # print(myDat)
    # print(calcShannonEnt(myDat))
    # print(splitDataSet(myDat,0,1))
    # print(splitDataSet(myDat, 0, 0))
    # print(chooseBestFeatureToSplit(myDat))
    # print(myDat)
    # myTree=createTree(myDat,labels)
    # print(myTree)
    myTree = retrieveTree(0)
    # myTree['no surfacing'][3]='maybe'
    print(myTree)
    # print(getNumLeafs(myTree))
    # print(getTreeDepth(myTree))
    # # createPlot(myTree)
    # print(classify(myTree,labels,[1,0]))
    # print(classify(myTree,labels,[1,1]))
    # Store and reload through the same path so the round trip reads back
    # the file that was just written.
    treeFile = '../MLinAction_source/classifierStorage.txt'
    storeTree(myTree, treeFile)
    print(grabTree(treeFile))
# Chapter 3: Decision Trees
# (Web-page footer captured with the listing: "latest recommended article published 2024-09-09 18:34:56")