决策树的一般流程:
- 收集数据:可以使用任何方法
- 准备数据:树构造算法只适用于标称型数据,因此数值型必须离散化
- 分析数据:可以使用任何方法,树构造完成后应该检查图形是否符合预期
- 训练算法:构造树的数据结构
- 测试算法:使用经验树计算错误率
- 使用算法:适用于任何监督学习算法,而使用决策树可更好的理解数据的内在含义
计算给定数据集的香农熵(集合信息的度量方式,度量数据集的无序程度)
熵定义为信息的期望值,熵越高,混合的数据也越多。
def calcShannonEnt(dataSet):
    """Compute the Shannon entropy of a data set's class labels.

    dataSet: list of records; the LAST element of each record is the class label.
    Returns the entropy (float) of the label distribution; 0.0 for an empty set.

    Fixes vs. original: `log` was used but never imported anywhere visible in
    the file; local typo `shannoEnt` corrected; empty input no longer raises
    ZeroDivisionError.
    """
    from math import log  # original called log() without importing it

    numEntries = len(dataSet)
    if numEntries == 0:
        return 0.0
    labelCounts = {}
    for featVec in dataSet:  # count occurrences of each class label
        currentLabel = featVec[-1]
        labelCounts[currentLabel] = labelCounts.get(currentLabel, 0) + 1
    shannonEnt = 0.0
    for count in labelCounts.values():
        prob = count / numEntries  # probability of this label
        shannonEnt -= prob * log(prob, 2)
    return shannonEnt
def createDataSet():
    """Build the tiny fish-identification toy data set used in this chapter.

    Returns (samples, featureNames): each sample is two binary features
    plus a 'yes'/'no' class label.
    """
    samples = [
        [1, 1, 'yes'],
        [1, 1, 'yes'],
        [1, 0, 'no'],
        [0, 1, 'no'],
        [0, 1, 'no'],
    ]
    featureNames = ['no surfacing', 'flippers']
    return samples, featureNames
划分数据集(度量划分数据集的熵,判断当前是否正确划分了数据集)
对每个特征划分数据集的结果计算一次信息熵,然后判断哪个特征是划分数据集的最好划分方式。
#splitDataSet(待划分的数据集、划分数据集的特征、特征的返回值)
def splitDataSet(dataset, axis, value):
    """Return the records whose feature at column `axis` equals `value`,
    with that feature column removed from each returned record.

    dataset: list of records (lists); not mutated.
    axis:    column index of the feature to split on.
    value:   feature value selecting the branch.

    Fix vs. original: three leftover debug print() calls removed (their
    output had even been pasted into the notes below the function).
    """
    retDataSet = []
    for featVec in dataset:
        if featVec[axis] == value:
            # Drop column `axis`; concatenation builds a new list, so the
            # original record is left untouched.
            retDataSet.append(featVec[:axis] + featVec[axis + 1:])
    return retDataSet
# Quick smoke test of the functions defined above.
sampleData, sampleLabels = createDataSet()
calcShannonEnt(sampleData)       # ~0.9710 for this data set
splitDataSet(sampleData, 0, 0)   # -> [[1, 'no'], [1, 'no']]
# Debug output printed by splitDataSet(m, 0, 0) in the original session
# (kept here as a record, commented out so the file stays valid Python):
# []
# [1, 'no']
# [[1, 'no']]
# []
# [1, 'no']
# [[1, 'no'], [1, 'no']]
#-----------------------------------------------------------------------------------------------
#遍历整个数据集,循环计算熵和splitDataSet(),找到最好的特征划分方式。
def chooseBestFeatureToSplit(dataSet):
    """Return the index of the feature with the highest information gain.

    dataSet: list of records, last column is the class label.
    Returns -1 if no split improves on the base entropy.
    """
    featureCount = len(dataSet[0]) - 1          # last column is the label
    baseEntropy = calcShannonEnt(dataSet)
    bestGain = 0.0
    bestFeature = -1
    for axis in range(featureCount):
        # Distinct values this feature takes across the data set.
        values = {record[axis] for record in dataSet}
        # Weighted entropy after splitting on this feature.
        splitEntropy = 0.0
        for v in values:
            subset = splitDataSet(dataSet, axis, v)
            weight = len(subset) / float(len(dataSet))
            splitEntropy += weight * calcShannonEnt(subset)
        gain = baseEntropy - splitEntropy
        if gain > bestGain:
            bestGain = gain
            bestFeature = axis
    return bestFeature
递归构建决策树
工作原理:得到原始数据集,然后基于最好的属性值划分数据集,由于特征值可能多于两个,因此可能存在大于2个分支的数据集划分。
采用递归处理:结束条件为(遍历完所有划分数据集的属性 or 每个分支下的所有实例都具有相同的分类)
def majorityCnt(classList):
    """Return the most frequent class label in classList.

    Fixes vs. original: it incremented the dict object itself
    (`classCount += 1`, a TypeError) instead of the vote's counter, and
    relied on Python-2-only `dict.iteritems()` plus an `operator` module
    that was never imported.
    """
    from collections import Counter
    # most_common(1) -> [(label, count)] for the highest-count label.
    return Counter(classList).most_common(1)[0][0]
def createTree(dataSet, labels):
    """Recursively build a decision tree as nested dicts.

    dataSet: records with the class label in the last column.
    labels:  feature names, parallel to the feature columns.
             NOTE: mutated in place (the chosen feature's name is deleted),
             matching the book's original behavior.
    Returns either a class label (leaf) or {featureName: {value: subtree}}.
    """
    classList = [record[-1] for record in dataSet]
    # Base case 1: every record carries the same label.
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    # Base case 2: no features left to split on -> majority vote.
    if len(dataSet[0]) == 1:
        return majorityCnt(classList)
    bestFeat = chooseBestFeatureToSplit(dataSet)
    bestLabel = labels[bestFeat]
    tree = {bestLabel: {}}
    del labels[bestFeat]
    for value in {record[bestFeat] for record in dataSet}:
        branchLabels = labels[:]  # copy so recursion can't clobber siblings
        tree[bestLabel][value] = createTree(
            splitDataSet(dataSet, bestFeat, value), branchLabels)
    return tree
使用Matplotlib注解绘制树形图
1.Matplotlib注解
import matplotlib.pyplot as plt
# Node and arrow styles shared by the tree-plotting helpers below.
decisionNode = dict(boxstyle="sawtooth", fc="0.8")  # internal (decision) nodes
leafNode = dict(boxstyle="round4", fc="0.8")        # leaf nodes
arrow_args = dict(arrowstyle="<-")                  # parent -> child edge arrows
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    """Draw one annotated node at centerPt with an arrow from parentPt.

    Relies on createPlot.ax1 (the axes created by createPlot) already
    existing; coordinates are in axes-fraction units.
    """
    createPlot.ax1.annotate(
        nodeTxt,
        xy=parentPt, xycoords='axes fraction',
        xytext=centerPt, textcoords='axes fraction',
        va="center", ha="center",
        bbox=nodeType, arrowprops=arrow_args,
    )
# =============================================================================
#subplot(row,col,plotNum)
# 将绘图区域分为row*col列子域,并且按照从左往右,上到下对每个
# 子区域编号,若row,col,plotNum都小于10,可用3位数字之间代替
# 在plotNum的区域中创建轴对象。
#plot(*args, **kwargs)
# plot(x, y) # plot x and y using default line style and color
# plot(x, y, 'bo') # plot x and y using blue circle markers
#figure(num=None, figsize=None, dpi=None, facecolor=None,
# edgecolor=None, frameon=True, FigureClass=<class 'matplotlib.figure.Figure'>,
# **kwargs)
#Creates a new figure
#clf()
# Clear the current figure.
#createPlot.ax1
# import dis
# def func():
# func().ax1=123
# dis.dis(func)
# 一切皆对象,对函数也可以添加属性
# =============================================================================
def createPlot():
    """Demo: draw one decision node and one leaf node on a blank figure."""
    figure = plt.figure(1, facecolor='white')
    figure.clf()
    # Stash the axes on the function object so plotNode can reach it.
    createPlot.ax1 = plt.subplot(111, frameon=False)
    plotNode('a decision node', (0.5, 0.1), (0.1, 0.5), decisionNode)
    plotNode('a leaf node', (0.8, 0.1), (0.3, 0.8), leafNode)
    plt.show()
2.构造注解树
def getNumLeafs(myTree):
    """Count the leaf nodes of a nested-dict decision tree."""
    rootKey = list(myTree)[0]
    total = 0
    for branch in myTree[rootKey].values():
        # A dict branch is a subtree; anything else is a leaf label.
        total += getNumLeafs(branch) if isinstance(branch, dict) else 1
    return total
def getTreeDepth(myTree):
    """Return the number of decision levels in a nested-dict tree."""
    rootKey = list(myTree)[0]  # py3: keys() is a view, so wrap in list()
    deepest = 0
    for branch in myTree[rootKey].values():
        depth = 1 + getTreeDepth(branch) if isinstance(branch, dict) else 1
        deepest = max(deepest, depth)
    return deepest
def retrieveTree(i):
    """Return one of two canned trees (i = 0 or 1) for testing/plotting."""
    cannedTrees = (
        {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},
        {'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}},
    )
    return cannedTrees[i]
# Quick check of the counting helpers.
myTree = retrieveTree(0)
getNumLeafs(myTree)   # -> 3
# Fix: the original called the non-existent getNumTreeDepth().
getTreeDepth(myTree)  # -> 2
------------------------以下将前面所学组合一起,绘制一颗完整树-----------------------------
def plotMidText(cntrPt, parentPt, txtString):
    """Write txtString at the midpoint of the edge from parentPt to cntrPt."""
    midX = cntrPt[0] + (parentPt[0] - cntrPt[0]) / 2.0
    midY = cntrPt[1] + (parentPt[1] - cntrPt[1]) / 2.0
    createPlot.ax1.text(midX, midY, txtString, va="center", ha="center", rotation=30)
def plotTree(myTree, parentPt, nodeTxt):
    """Recursively draw the tree rooted at myTree under parentPt.

    Uses function attributes as shared plotting state (set by createPlot):
    plotTree.totalW / totalD  - total leaf count / depth of the whole tree;
    plotTree.xOff / yOff      - cursor for the next leaf / current level.

    Reconstructed from the original paste, which had many statements
    collapsed onto a single line and the closing statement fused with the
    following `def createPlot` header (syntactically invalid).
    """
    numLeafs = getNumLeafs(myTree)  # width of this subtree, in leaf slots
    firstStr = list(myTree.keys())[0]  # feature name labelling this node
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW,
              plotTree.yOff)
    plotMidText(cntrPt, parentPt, nodeTxt)  # label the incoming edge
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
    plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD  # descend one level
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            # Dict branch -> subtree: recurse.
            plotTree(secondDict[key], cntrPt, str(key))
        else:
            # Leaf: advance the x cursor and draw the label.
            plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD  # back up one level

def createPlot(inTree):
    """Draw the whole decision tree inTree on a fresh matplotlib figure."""
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    axprops = dict(xticks=[], yticks=[])
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
    plotTree.totalW = float(getNumLeafs(inTree))
    plotTree.totalD = float(getTreeDepth(inTree))
    plotTree.xOff = -0.5 / plotTree.totalW
    plotTree.yOff = 1.0
    plotTree(inTree, (0.5, 1.0), '')
    plt.show()
def createPlot(inTree):
    """Render the decision tree inTree on a fresh figure.

    Duplicate definition kept from the original notes; being defined later,
    it is the one that takes effect.
    """
    figure = plt.figure(1, facecolor='white')
    figure.clf()
    axisOptions = dict(xticks=[], yticks=[])
    createPlot.ax1 = plt.subplot(111, frameon=False, **axisOptions)
    # Seed the shared plotting state consumed by plotTree.
    plotTree.totalW = float(getNumLeafs(inTree))
    plotTree.totalD = float(getTreeDepth(inTree))
    plotTree.xOff = -0.5 / plotTree.totalW
    plotTree.yOff = 1.0
    plotTree(inTree, (0.5, 1.0), '')
    plt.show()
测试和存储分类器
#1.测试算法:创建使用决策树的分类器
def classify(inputTree, featLabels, testVec):
    """Classify testVec by walking the decision tree inputTree.

    featLabels maps feature names (tree keys) back to column indices in
    testVec. Note: like the original, raises a NameError if testVec's value
    matches no branch of some node.
    """
    rootLabel = list(inputTree.keys())[0]
    branches = inputTree[rootLabel]
    featIndex = featLabels.index(rootLabel)  # feature name -> column index
    for branchValue, subtree in branches.items():
        if testVec[featIndex] == branchValue:
            if isinstance(subtree, dict):
                classLabel = classify(subtree, featLabels, testVec)
            else:
                classLabel = subtree
    return classLabel
# =============================================================================
# In[]: m,l=trees.createDataSet()
# Out[]: ['no surfacing', 'flippers']
#
# In[]: myTree=treePlotter.retrieveTree(0)
#
# In[]: myTree
# Out[]: {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
# In[]: trees.classify(myTree,l,[1,0])
# Out[]: 'no'
# =============================================================================
#2.使用算法:决策树的存储
def storeTree(inputTree, filename):
    """Serialize inputTree to `filename` with pickle.

    Fixes vs. original: pickle writes bytes, so the file must be opened in
    binary mode ('wb', not 'w' — text mode raises TypeError in Python 3);
    a context manager guarantees the handle is closed even on error.
    """
    import pickle
    with open(filename, 'wb') as fw:
        pickle.dump(inputTree, fw)
def grabTree(filename):
    """Load and return a pickled tree from `filename`.

    Fixes vs. original: pickle reads bytes, so open in binary mode
    ('rb'); the context manager closes the file (the original leaked the
    handle).
    """
    import pickle
    with open(filename, 'rb') as fr:
        return pickle.load(fr)
# (end of notes)