#coding:utf-8
from math import log
import operator
#Shannon entropy of the class labels (last column) of dataSet.
def calcShannonEnt(dataSet):
    total = len(dataSet)
    #tally how often each class label occurs
    counts = {}
    for row in dataSet:
        label = row[-1]
        counts[label] = counts.get(label, 0) + 1
    #H = -sum(p * log2(p)) over every observed label
    entropy = 0.0
    for cnt in counts.values():
        p = cnt / float(total)
        entropy -= p * log(p, 2)
    return entropy
def createDataSet():
    """Return the toy "is it a fish?" dataset and its feature names."""
    rows = [
        [1, 1, 'yes'],
        [1, 1, 'yes'],
        [1, 0, 'no'],
        [0, 1, 'no'],
        [0, 1, 'no'],
    ]
    featureNames = ['no surfacing', 'flippers']
    return rows, featureNames
#Select the rows whose entry in column `axis` equals `value`, returning
#them with that column removed (the feature is consumed by the split).
def splitDataSet(dataSet, axis, value):
    return [row[:axis] + row[axis + 1:] for row in dataSet if row[axis] == value]
#Pick the feature column whose split yields the highest information gain,
#i.e. the split that leaves the lowest weighted entropy. Returns -1 when
#no split improves on the base entropy.
def chooseBestFitureToSplit(dataSet):
    featureCount = len(dataSet[0]) - 1
    baseEntropy = calcShannonEnt(dataSet)
    bestGain, bestIndex = 0.0, -1
    for col in range(featureCount):
        #distinct values this column takes in the dataset
        distinctValues = {row[col] for row in dataSet}
        #expected entropy after splitting on this column
        newEntropy = 0.0
        for v in distinctValues:
            subset = splitDataSet(dataSet, col, v)
            weight = len(subset) / float(len(dataSet))
            newEntropy += weight * calcShannonEnt(subset)
        gain = baseEntropy - newEntropy
        if gain > bestGain:
            bestGain, bestIndex = gain, col
    return bestIndex
def majorityCnt(classList):
    """Return the class label occurring most often in classList.

    Fixes three bugs in the original: membership was tested against
    ``classList.keys()`` (a list has no ``.keys()``, so the first call
    raised AttributeError and the counter dict was never filled); the
    return indexed ``classCount[0][0]`` instead of the sorted result;
    and ``dict.iteritems()`` is Python-2-only.
    """
    classCount = {}
    for vote in classList:
        classCount[vote] = classCount.get(vote, 0) + 1
    #sort label/count pairs by count, highest first; stable sort keeps
    #first-seen order among ties
    sortedClassCount = sorted(classCount.items(),
                              key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]
#Recursively build the ID3 decision tree as nested dicts:
#{feature_name: {feature_value: subtree_or_class_label, ...}}
def createTree(dataSet, labels):
    """Build and return a decision tree for dataSet.

    ``labels`` holds the feature names matching dataSet's columns. The
    original destructively ran ``del(labels[bestFeat])`` on the caller's
    list; this version works on a copy so the caller's labels survive.
    """
    classList = [row[-1] for row in dataSet]
    #stop when every remaining sample carries the same class
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    #all features consumed: fall back to a majority vote
    if len(dataSet[0]) == 1:
        return majorityCnt(classList)
    bestFeat = chooseBestFitureToSplit(dataSet)
    bestFeatLabel = labels[bestFeat]
    myTree = {bestFeatLabel: {}}
    #feature names left after consuming the chosen column (no caller mutation)
    remainingLabels = labels[:bestFeat] + labels[bestFeat + 1:]
    featValues = set(row[bestFeat] for row in dataSet)
    #one branch per observed value of the chosen feature
    for value in featValues:
        myTree[bestFeatLabel][value] = createTree(
            splitDataSet(dataSet, bestFeat, value), remainingLabels[:])
    return myTree
#coding:utf-8
import matplotlib.pyplot as plt
import pdb
# Matplotlib annotation styles shared by the tree-plotting helpers below.
decision_node = dict(boxstyle="sawtooth", fc="0.8")  # box for internal (decision) nodes
leaf_node = dict(boxstyle="round4", fc="0.8")  # box for leaf (class label) nodes
arrow_args = dict(arrowstyle="<-")  # arrow drawn along each tree edge
#Draw nodeTxt in a styled box at centerPt, with an arrow from parentPt.
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    annotateOpts = dict(xycoords='axes fraction', textcoords='axes fraction',
                        va="center", ha="center",
                        bbox=nodeType, arrowprops=arrow_args)
    createPlot.ax1.annotate(nodeTxt, xy=parentPt, xytext=centerPt,
                            **annotateOpts)
def getNumLeafs(myTree):
    """Recursively count the leaf nodes of a nested-dict decision tree.

    Uses ``next(iter(myTree))`` instead of the original
    ``myTree.keys()[0]``: in Python 3 ``dict.keys()`` is a view and does
    not support indexing, so the original raised TypeError.
    """
    numLeafs = 0
    firstStr = next(iter(myTree))  # root feature name of this subtree
    secondDict = myTree[firstStr]
    for key in secondDict:
        if isinstance(secondDict[key], dict):  # internal node: recurse
            numLeafs += getNumLeafs(secondDict[key])
        else:  # leaf: a class label
            numLeafs += 1
    return numLeafs
def getTreeDepth(myTree):
    """Recursively compute the depth (number of decision levels) of a tree.

    Uses ``next(iter(myTree))`` instead of the original
    ``myTree.keys()[0]``, which fails on Python 3 where ``dict.keys()``
    returns a non-indexable view.
    """
    maxDepth = 0
    firstStr = next(iter(myTree))  # root feature name of this subtree
    secondDict = myTree[firstStr]
    for key in secondDict:
        if isinstance(secondDict[key], dict):  # internal node: one level + subtree
            thisDepth = 1 + getTreeDepth(secondDict[key])
        else:  # leaf contributes a single level
            thisDepth = 1
        if thisDepth > maxDepth:
            maxDepth = thisDepth
    return maxDepth
def retrive_tree(i):
    """Return the i-th hard-coded sample tree (fixture for the plot code)."""
    sampleTrees = (
        {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},
        {'no surfacing': {0: 'no',
                          1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}},
                                           1: 'no'}}}},
    )
    return sampleTrees[i]
#Write txtString at the midpoint of the edge between a child node
#(cntrPt) and its parent (parentPt).
def plotMidText(cntrPt, parentPt, txtString):
    midX = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]
    midY = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]
    createPlot.ax1.text(midX, midY, txtString)
'''
plotTree.xOff / plotTree.yOff track the position of the last node drawn;
plotTree.totalW / plotTree.totalD hold the whole tree's leaf count and depth.
cntrPt is the position of the current (child) node.
'''
def plotTree(myTree, parentPt, nodeTxt):
    """Recursively draw myTree hanging from parentPt on createPlot.ax1.

    Fixes the Python-2-only ``myTree.keys()[0]`` (non-indexable view in
    Python 3) and drops the original's unused ``depth`` local.
    """
    #leaf count decides this subtree's horizontal spread
    numLeafs = getNumLeafs(myTree)
    firstStr = next(iter(myTree))  # feature tested at this node
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW,
              plotTree.yOff)
    plotMidText(cntrPt, parentPt, firstStr)
    plotNode(firstStr, cntrPt, parentPt, decision_node)
    secondDict = myTree[firstStr]
    #step down one level before drawing the children
    plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD
    for key in secondDict:
        if isinstance(secondDict[key], dict):
            #internal child: recurse with this node as the parent
            plotTree(secondDict[key], cntrPt, str(key))
        else:
            #leaf child: advance xOff one slot and draw the label box
            plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff),
                     cntrPt, leaf_node)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    #restore yOff while unwinding the recursion
    plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD
'''
Render the decision tree inTree with matplotlib and show the figure.
'''
def createPlot(inTree):
    """Set up the figure/axes and delegate drawing to plotTree.

    Removes the leftover ``pdb.set_trace()`` debugging breakpoint that
    halted every call waiting for debugger input.
    """
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    axprops = dict(xticks=[], yticks=[])
    #shared axes object that plotNode/plotMidText draw onto
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
    #global layout state consumed by plotTree
    plotTree.totalW = float(getNumLeafs(inTree))
    plotTree.totalD = float(getTreeDepth(inTree))
    plotTree.xOff = -0.5 / plotTree.totalW
    plotTree.yOff = 1.0
    plotTree(inTree, (0.5, 1.0), '')
    plt.show()
'''
inputTree:  the decision tree (nested dicts)
featLabels: names of all feature columns
testVec:    feature values of the sample to classify
example: classify(myTree, labels, [1, 0])
'''
def classify(inputTree, featLabels, testVec):
    """Walk the tree guided by testVec's feature values; return the label.

    Fixes the Python-2-only ``inputTree.keys()[0]`` (non-indexable view
    in Python 3). Raises KeyError when testVec holds a value the tree has
    no branch for — the original crashed with UnboundLocalError there.
    """
    firstStr = next(iter(inputTree))  # feature tested at this node
    secondDict = inputTree[firstStr]
    featIndex = featLabels.index(firstStr)
    branch = secondDict[testVec[featIndex]]
    if isinstance(branch, dict):
        #internal node: keep descending
        return classify(branch, featLabels, testVec)
    #leaf: the class label itself
    return branch
#Persist a decision tree to disk with pickle.
def storeStree(inputTree, filename):
    """Pickle inputTree to filename.

    Opens the file in binary mode ('wb'): pickle emits bytes, and the
    original text-mode 'w' raises TypeError on Python 3 (and corrupts
    the stream on Windows under Python 2). The with-statement guarantees
    the file is closed even if dump fails.
    """
    import pickle
    with open(filename, 'wb') as fw:
        pickle.dump(inputTree, fw)
def grabStree(filename):
    """Load and return a tree previously saved by storeStree.

    Opens in binary mode ('rb') to match pickle's byte stream (the
    original text-mode open fails on Python 3) and uses a with-statement
    so the file handle is closed (the original leaked it).
    """
    import pickle
    with open(filename, 'rb') as fr:
        return pickle.load(fr)