目录
一、决策树相关概念介绍
1、什么是决策树/判定是?
判定树是一个类似于流程图的树结构:其中,每个内部结点表示在一个属性上的测试,每个分支代表一个属性输出,而每个树叶结点代表类或类分布。树的最顶层是根结点。
2、决策树优缺点
决策树的优点:直观,便于理解,小规模数据集有效
决策树的缺点: 处理连续变量不好; 类别较多时,错误增加的比较快;可规模性一般
3、熵概念
一条信息的信息量大小和它的不确定性有直接的关系,要搞清楚一件非常非常不确定的事情,或者是我们一无所知的事情,需要了解大量信息==>信息量的度量就等于不确定性的多少
二、决策树归纳算法
选择属性判断结点
信息获取量(Information Gain):Gain(A) = Info(D) - Infor_A(D)
通过A来作为节点分类获取了多少信息
1、举个栗子
2、算法
3、树剪枝叶(避免overfitting)
(1)先剪枝:比如比例达到多少后就不考虑分支了
(2)后剪枝:树建立完成后如果太大了再考虑剪枝
三、代码实现
1、创建数据集
#创建数据集
def createDataSet():
dataSet = [[1,1,'yes'],
[1,1,'yes'],
[1,0,'no'],
[0,1,'no'],
[0,1,'no']]
labels = ['no surfacing','flippers']
return dataSet,labels
2、计算给定数据集的香农熵
#计算给定数据集的香农熵
def calcShannonEnt(dataSet):
numEntries = len(dataSet) #计算数据集中实例总数
labelCounts = {} #存储标签
for featVec in dataSet:
currentLabel = featVec[-1]
if currentLabel not in labelCounts.keys(): #如果当前键不存在,就扩展字典并将当前键值加入字典
labelCounts[currentLabel] = 0
labelCounts[currentLabel] += 1 #记录出现的总次数--频率
shannonEnt = 0.0 #熵值初始化为0
for key in labelCounts:
prob = float(labelCounts[key]) / numEntries
shannonEnt -= prob * log(prob, 2) #以2为底求对数
return shannonEnt
3、按照给定特征划分数据集
#按照给定特征划分数据集
#dataSet:待划分的数据集
#axis:划分数据集的特征
#value:特征的返回值
def splitDataSet(dataSet,axis,value):
#注意:python语言在函数中传递的是列表的引用,在函数内部对列表对象进行修改,将会影响该列表对象的整个生存周期
#为了消除这个不良的影响,我们需要在函数开始声明一个新列表对象
retDataSet=[]
for featVec in dataSet:
#当我们按照某个特征划分数据集时,就需要将所有符合要求的元素抽取出来
if featVec[axis]==value: #将符合特征的数据抽取出来
reducedFeatVec = featVec[:axis]
reducedFeatVec.extend(featVec[axis + 1:])
retDataSet.append(reducedFeatVec)
return retDataSet
4、选择最好的数据集划分方式
#选择最好的数据集划分方式
#该函数实现选取特征,划分数据集,计算得出最好的划分数据集的特征
def chooseBestFeatureToSpit(dataSet):
numFeatures = len(dataSet[0]) - 1 #表示特征值的列数
baseEntropy = calcShannonEnt(dataSet) #计算整个数据集的原始香农熵
bestInfoGain = 0.0
bestFeature = -1
for i in range(numFeatures):
featList = [example[i] for example in dataSet] #创建唯一的分类标签列表,将数据集中蓑鲉第i个特征值或者所有可能存在的值写入这个新的list
uniqueVals = set(featList)#集合和列表的区别是可以最快得到列表中唯一元素
newEntropy = 0.0
#计算每种划分方式的信息熵
for value in uniqueVals:
subDataSet = splitDataSet(dataSet, i, value)
prob = len(subDataSet) / float(len(dataSet))
newEntropy += prob * calcShannonEnt(subDataSet)
infoGain = baseEntropy - newEntropy
#计算最好的信息增益
if (infoGain > bestInfoGain):
bestInfoGain = infoGain
bestFeature = i
#返回最好特征划分的索引值,从0开始
return bestFeature
5、测试
import tree
import treePlotter
#创建数据集
dataSet,labels=tree.createDataSet()
print(dataSet)
#熵的计算函数
shannonEnt=tree.calcShannonEnt(dataSet)
print(shannonEnt)
#划分数据集
retDataSet=tree.splitDataSet(dataSet, 0, 1) #表示选择第一列值为1的
print(retDataSet)
#retDataSet=tree1.splitDataSet(dataSet,0,0) #表示选择第一列值为0的
#print(retDataSet)
6、多数表决
#多数表决
#如果数据集已经处理了所有属性,但是类标签依然不是唯一的此时我们需要决定如何定义该叶子结点,在这种情况下,我们通常会采用多数表决的方式决定该叶子结点的分类
#classList:标签列表
def majorityCnt(classList):
classCount={}
for vote in classList:
if vote not in classCount.keys():
classCount[vote]=0
classCount[vote]+=1
sortedClassCount=sorted(classCount.iteritems(),key=operator.itemgetter(1),reverse=True)
return sortedClassCount[0][0]
7、创建决策树
#创建决策树
#dataSet:数据集
#labels:标签列表-包含数据集中所有特征的标签
def createTree(dataSet,labels):
classList=[example[-1] for example in dataSet] #classList包含数据集的所有类标签
if classList.count(classList[0]) == len(classList):#递归函数停止的第一个条件:所有类标签完全相同,则直接返回该类标签
return classList[0]
if len(dataSet[0]) == 1 :#递归函数停止的第二个条件:使用完了所有特征,仍然不能将数据集划分成仅包含唯一类别的分组
return majorityCnt(classList) #使用多数表决方法,挑选出现次数最多的类别作为返回值
bestFeat=chooseBestFeatureToSpit(dataSet)#选择最好特征划分的索引值
bestFeatLabel=labels[bestFeat]
myTree={bestFeatLabel:{}}#创建树,字典类型存储树信息
del(labels[bestFeat]) #此处,标签列表要随着子集变化而变化
featValues=[example[bestFeat] for example in dataSet]#找出决策节点后,继续深入分析特征值
uniqueVals=set(featValues)#找出最佳分类的特征值的所有可能取值存储
for value in uniqueVals:
subLabels=labels[:]#复制类标签,此时已分类好的特征值被删除了
#此时数据集和类标签都已更新,再递归调用
myTree[bestFeatLabel][value]=createTree(splitDataSet(dataSet, bestFeat, value),subLabels)
return myTree
8、判定数据属于哪个分类
#使用决策树的分类函数
#具体判断数据属于哪个分类
#inputTree:构造好的树--{'flippers': {0: 'no', 1: {'no surfacing': {0: 'no', 1: 'yes'}}}}
#featLabels:分类标签
#testVec:需要测试的数据,看它属于哪个分类
def classify(inputTree,featLabels,testVec):
firstStr = list(inputTree.keys())[0] #flippers 表示第一个分类的标签
secondDict = inputTree[firstStr]#{0: 'no', 1: {'no surfacing': {0: 'no', 1: 'yes'}}} 表示第一个分类标签对应的列表
featIndex = featLabels.index(firstStr)#1 表示第一个分类的标签flippers在featLabels中的索引
for key in secondDict.keys():
if testVec[featIndex] == key:
if type(secondDict[key]).__name__ == 'dict':
classLabel = classify(secondDict[key], featLabels, testVec)
else:
classLabel = secondDict[key]
return classLabel
9、测试
myTree=tree.createTree(dataSet, labels)
print(myTree)
#使用决策树的分类函数
classLabel=tree.classify(myTree,labels,[1,0])
print(classLabel)
10、使用pickle模块存储/读取决策树
#使用pickle模块存储决策树
def storeTree(inputTree,filename):
import pickle
fw = open(filename, "wb")
pickle.dump(inputTree, fw)
fw.close()
#使用pickle模块读取决策树
def grabTree(filename):
import pickle
fr=open(filename,'rb')
return pickle.load(fr)
myStoreTree=tree.grabTree('classifierStorage.txt')
print(myStoreTree)
四、绘制决策树
1、需引入matplotlib
#绘制树
import matplotlib as mpl
mpl.use('TkAgg')
import matplotlib.pyplot as plt
2、获取叶子节点数目和树的层数
#获取叶节点的数目
#myTree格式:{'flippers': {0: 'no', 1: {'no surfacing': {0: 'no', 1: 'yes'}}}}
def getNumLeafs(myTree):
numLeafs=0
#注意:你将结果传递somedict.keys()给函数。在Python 3中,dict.keys不返回一个列表,而是一个表示库键和视图(类似于set)的类集对象,不支持索引。
#要解决该问题,请使用收集密钥并使用list(somedict.keys())密钥。
firstStr=list(myTree.keys())[0] #第一个关键字是第一次划分数据集的类标签:flippers
secondDict=myTree[firstStr]#根据第一个关键字获取下层目录
for key in secondDict.keys():
if type(secondDict[key]).__name__=='dict':#判断子结点是否为字典类型
numLeafs+=getNumLeafs(secondDict[key])#如果是字典类型,则该节点也是一个判断节点,需递归调用
else:#如果不是字典类型,则说明是叶子节点
numLeafs+=1
return numLeafs
#获取数的层数
def getTreeDepth(myTree):
maxDepth=0
firstStr=list(myTree.keys())[0]
secondDict=myTree[firstStr]
for key in secondDict.keys():
if type(secondDict[key]).__name__=='dict':
thisDepth=1+getTreeDepth(secondDict[key])
else:
thisDepth=1
if thisDepth>maxDepth:
maxDepth=thisDepth
return maxDepth
3、定义一些样式
#这个是用来一注释形式绘制节点和箭头线,可以不用管
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',
xytext=centerPt, textcoords='axes fraction',
va="center", ha="center", bbox=nodeType, arrowprops=arrow_args )
#这个是用来绘制线上的标注,简单
def plotMidText(cntrPt, parentPt, txtString):
xMid = (parentPt[0]-cntrPt[0])/2.0 + cntrPt[0]
yMid = (parentPt[1]-cntrPt[1])/2.0 + cntrPt[1]
createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center", rotation=30)
4、决定树的绘制(逻辑绘制)
# 重点,递归,决定整个树图的绘制,难(自己认为)
def plotTree(myTree, parentPt, nodeTxt): # if the first key tells you what feat was split on
numLeafs = getNumLeafs(myTree) # this determines the x width of this tree
depth = getTreeDepth(myTree)
firstStr = list(myTree.keys())[0] # the text label for this node should be this
cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW, plotTree.yOff)
plotMidText(cntrPt, parentPt, nodeTxt)
plotNode(firstStr, cntrPt, parentPt, decisionNode)
secondDict = myTree[firstStr]
plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD
for key in secondDict.keys():
if type(secondDict[key]).__name__ == 'dict': # test to see if the nodes are dictonaires, if not they are leaf nodes
plotTree(secondDict[key], cntrPt, str(key)) # recursion
else: # it's a leaf node print the leaf node
plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD
5、实际绘制树
# 这个是真正的绘制,上边是逻辑的绘制
def createPlot(inTree):
fig = plt.figure(1, facecolor='white')
fig.clf()
axprops = dict(xticks=[], yticks=[])
createPlot.ax1 = plt.subplot(111, frameon=False) # no ticks
plotTree.totalW = float(getNumLeafs(inTree))
plotTree.totalD = float(getTreeDepth(inTree))
plotTree.xOff = -0.5 / plotTree.totalW;
plotTree.yOff = 1.0;
plotTree(inTree, (0.5, 1.0), '')
plt.show()
6、创建数据集
#这个是用来创建数据集即决策树
def retrieveTree(i):
listOfTrees =[{'no surfacing': {0:'no',1:{'flippers':{0:'no',1:'yes'}}}},
{'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}},
{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}},3:'maybe'}}
]
return listOfTrees[i]
7、测试
import treePlotter
treePlotter.createPlot(treePlotter.retrieveTree(2))