# -- coding: utf-8 --
#from knn import*
from treePlotter import*
import matplotlib.pyplot as plt

# Driver script: build a decision tree from the toy dataset, then one from
# the contact-lens data file, and plot the latter.
myDat, labels = creatDataSet()
label = labels[:]  # slice-copy: creatTree mutates its labels argument, keep an intact copy
print(myDat)
print(labels)
myTree = creatTree(myDat, labels)  # NOTE: creatTree deletes entries from `labels`
print(myTree)
filepath = r'E:\file\python\test\test\Tree_data\classifierstorage.txt'  # raw string keeps backslashes literal

# Contact-lens decision tree
glasspath = r'E:\file\python\test\test\Tree_data\lenses.txt'
# fix: use a context manager so the data file is closed (original leaked the handle)
with open(glasspath) as fr:
    lenses = [inst.strip().split('\t') for inst in fr]  # strip() removes surrounding whitespace
lenseslables = ['age', 'prescript', 'astigmatic', 'tearRate']
lensesTree = creatTree(lenses, lenseslables)
print(lensesTree)
creatPlot(lensesTree)
# -- coding: utf-8 --
#treePlotter
from numpy import*
from math import log
import operator
import matplotlib.pyplot as plt
import pickle
def creatDataSet():
    """Return the toy fish dataset and its feature label names.

    Each sample is [no-surfacing, flippers, class], where the last
    column is the class label ('yes' = fish, 'no' = not fish).
    """
    samples = [
        [1, 1, 'yes'],
        [1, 1, 'yes'],
        [1, 0, 'no'],
        [0, 1, 'no'],
        [0, 1, 'no'],
    ]
    featureNames = ['no surfacing', 'flippers']  # names of the two feature columns
    return samples, featureNames
# Shannon entropy of the class-label column
def clacShannonEnt(dataSet):
    """Return the Shannon entropy of dataSet's class labels.

    dataSet: list of samples; the LAST element of each sample is its class
    label. Entropy is -sum(p * log2(p)) over the label distribution.
    """
    numEntries = len(dataSet)
    labelCounts = {}
    for featVec in dataSet:
        currentLabel = featVec[-1]
        # idiom fix: dict.get replaces the `not in d.keys()` membership test
        labelCounts[currentLabel] = labelCounts.get(currentLabel, 0) + 1
    shannonEnt = 0.0
    for key in labelCounts:
        prob = float(labelCounts[key]) / numEntries
        shannonEnt -= prob * log(prob, 2)  # accumulate -p*log2(p)
    return shannonEnt
# Partition the dataset on one feature value
def splitDataSet(dataSet, axis, value):
    """Return the samples whose feature at index `axis` equals `value`,
    with that feature column removed from each returned sample."""
    return [row[:axis] + row[axis + 1:] for row in dataSet if row[axis] == value]
def chooseBestFeatureToSplit(dataSet):
    """Return the index of the feature with the highest information gain.

    Gain = dataset entropy minus the split's conditional entropy; returns
    -1 if no split improves on zero gain.
    """
    featureCount = len(dataSet[0]) - 1      # last column is the class label
    baseEntropy = clacShannonEnt(dataSet)   # entropy before any split
    bestGain, bestFeature = 0.0, -1
    for featIdx in range(featureCount):
        distinctVals = {sample[featIdx] for sample in dataSet}
        condEntropy = 0.0
        for val in distinctVals:
            subset = splitDataSet(dataSet, featIdx, val)
            weight = len(subset) / float(len(dataSet))
            condEntropy += weight * clacShannonEnt(subset)  # weighted child entropy
        gain = baseEntropy - condEntropy
        if gain > bestGain:  # strictly greater: first feature wins ties
            bestGain, bestFeature = gain, featIdx
    return bestFeature
def majorityCnt(classList):  # majority vote
    """Return the most frequent class label in classList.

    Fix: the original called dict.iteritems(), which is Python 2 only and
    raises AttributeError on Python 3 (this file already uses Python 3
    print()); use items() instead. Ties go to the first-counted label
    (sorted() is stable).
    """
    classCount = {}
    for vote in classList:
        classCount[vote] = classCount.get(vote, 0) + 1
    sortedClassCount = sorted(classCount.items(),
                              key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]
def creatTree(dataSet, labels):
    """Recursively build a decision tree as nested dicts.

    dataSet: samples whose last column is the class label.
    labels: feature names aligned with the feature columns.
    NOTE: mutates `labels` — the chosen feature's name is deleted at each
    level (callers should pass a copy if they need the list afterwards).
    Returns either a class label (leaf) or {featureName: {value: subtree}}.
    """
    classList = [row[-1] for row in dataSet]
    # base case 1: every sample has the same class -> leaf
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    # base case 2: only the label column remains -> majority-vote leaf
    if len(dataSet[0]) == 1:
        return majorityCnt(classList)
    bestFeat = chooseBestFeatureToSplit(dataSet)   # column index of best split
    bestFeatLabel = labels[bestFeat]               # its human-readable name
    tree = {bestFeatLabel: {}}
    del labels[bestFeat]
    for value in set(row[bestFeat] for row in dataSet):
        # recurse on each partition with a copy of the remaining labels
        tree[bestFeatLabel][value] = creatTree(
            splitDataSet(dataSet, bestFeat, value), labels[:])
    return tree
# Count leaves / measure depth of the nested-dict tree
def getNumleafs(myTree):
    """Return the number of leaf nodes in a nested-dict decision tree.

    Fix: the original used myTree.keys()[0], which raises TypeError on
    Python 3 (dict views are not subscriptable); next(iter(...)) gets the
    single root key on both versions.
    """
    numLeafs = 0
    firstStr = next(iter(myTree))     # the tree's root feature name
    secondDict = myTree[firstStr]     # its value: {featureValue: subtree-or-label}
    for key in secondDict:
        if isinstance(secondDict[key], dict):  # a dict value is a subtree
            numLeafs += getNumleafs(secondDict[key])
        else:
            numLeafs += 1                      # non-dict value is a leaf
    return numLeafs
def getTreeDepth(myTree):
    """Return the depth (number of decision levels) of a nested-dict tree.

    Fix: myTree.keys()[0] raises TypeError on Python 3; use
    next(iter(...)). Also isinstance() instead of type() == dict.
    """
    maxDepth = 0
    firstStr = next(iter(myTree))   # root feature name
    secondDict = myTree[firstStr]
    for key in secondDict:
        if isinstance(secondDict[key], dict):
            thisDepth = 1 + getTreeDepth(secondDict[key])  # recurse into subtree
        else:
            thisDepth = 1                                  # leaf
        if thisDepth > maxDepth:
            maxDepth = thisDepth
    return maxDepth
# Node and arrow styles used by matplotlib's annotate() when drawing the tree
decisionNode = {"boxstyle": "sawtooth", "fc": "0.8"}  # internal decision node
leafNode = {"boxstyle": "round4", "fc": "0.8"}        # leaf node
arrow_args = {"arrowstyle": "<-"}                     # edge arrow style
# Draw a single node
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    """Draw one annotated node on creatPlot.ax1.

    nodeTxt: text inside the node box; parentPt: arrow start (the parent);
    centerPt: where the box is placed; nodeType: box style dict
    (decisionNode or leafNode). Coordinates are axes fractions.
    """
    creatPlot.ax1.annotate(
        nodeTxt,
        xy=parentPt, xycoords='axes fraction',
        xytext=centerPt, textcoords='axes fraction',
        va="center", ha="center",
        bbox=nodeType, arrowprops=arrow_args,
    )
# Draw the edge label between parent and child
def plotMidText(cntrPt, parentPt, txtString):
    """Place txtString (the feature value, e.g. 0 or 1) at the midpoint of
    the edge between a parent node and its child."""
    midX = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]
    midY = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]
    creatPlot.ax1.text(midX, midY, txtString)
# Recursively draw the tree
def plotTree(myTree, parentPt, nodeTxt):
    """Recursively draw the decision tree `myTree`.

    parentPt: position of the parent node; nodeTxt: edge label to draw
    between parent and this node. Layout state lives in function
    attributes (totalW, totalD, x0ff, y0ff) initialised by creatPlot.
    Fix: myTree.keys()[0] raises TypeError on Python 3; use
    next(iter(...)). (An unused `depth = getTreeDepth(...)` local was
    also removed.)
    """
    numLeafs = getNumleafs(myTree)     # width of this subtree in leaf slots
    firstStr = next(iter(myTree))      # feature name at this subtree's root
    # centre the node horizontally over its leaves; y0ff is the current level
    cntrPt = (plotTree.x0ff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW,
              plotTree.y0ff)
    plotMidText(cntrPt, parentPt, nodeTxt)          # edge label
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
    plotTree.y0ff = plotTree.y0ff - 1.0 / plotTree.totalD  # descend one level
    for key in secondDict:
        if isinstance(secondDict[key], dict):
            plotTree(secondDict[key], cntrPt, str(key))    # recurse into subtree
        else:
            plotTree.x0ff = plotTree.x0ff + 1.0 / plotTree.totalW  # next leaf slot
            plotNode(secondDict[key], (plotTree.x0ff, plotTree.y0ff), cntrPt, leafNode)
            plotMidText((plotTree.x0ff, plotTree.y0ff), cntrPt, str(key))
    plotTree.y0ff = plotTree.y0ff + 1.0 / plotTree.totalD  # ascend back up
def creatPlot(inTree):
    """Render the decision tree `inTree` in a matplotlib figure and show it."""
    figure = plt.figure(1, facecolor='white')  # (re)use figure 1
    figure.clf()                               # wipe any previous drawing
    axisOpts = dict(xticks=[], yticks=[])      # hide both tick axes
    # single frameless axes; plotNode/plotMidText draw via creatPlot.ax1
    creatPlot.ax1 = plt.subplot(111, frameon=False, **axisOpts)
    # shared layout state consumed by plotTree
    plotTree.totalW = float(getNumleafs(inTree))   # total leaves -> horizontal scale
    plotTree.totalD = float(getTreeDepth(inTree))  # depth -> vertical scale
    plotTree.x0ff = -0.5 / plotTree.totalW         # first leaf slot, half a cell left
    plotTree.y0ff = 1.0                            # start at the top
    plotTree(inTree, (0.5, 1.0), '')               # root at top centre, no edge label
    plt.show()
# Use the decision tree to classify a sample
def classify(inputTree, featlabels, testVec):
    """Classify testVec by walking the decision tree.

    inputTree: nested-dict tree from creatTree; featlabels: full feature
    name list (maps feature names back to testVec indices); testVec: one
    sample's feature values. Returns the predicted class label.
    Fixes: (1) inputTree.keys()[0] raises TypeError on Python 3;
    (2) `secondDict[key] == dict` compared a VALUE to the dict type —
    always False — so subtrees were returned instead of recursed into;
    use isinstance().
    Raises NameError (unchanged from original) if testVec's value for the
    root feature matches no branch.
    """
    firstStr = next(iter(inputTree))            # feature tested at this node
    secondDict = inputTree[firstStr]
    featIndex = featlabels.index(firstStr)      # position of that feature in testVec
    for key in secondDict:
        if testVec[featIndex] == key:
            if isinstance(secondDict[key], dict):
                classlabel = classify(secondDict[key], featlabels, testVec)
            else:
                classlabel = secondDict[key]    # leaf: the class label itself
    return classlabel
# Persist the decision tree with the pickle module
# Store the tree
def storeTree(inputTree, filename):
    """Serialize inputTree to `filename` with pickle.

    Fix: pickle.dump requires a binary-mode file; text mode ('w') raises
    TypeError on Python 3 — open with 'wb'. The context manager also
    guarantees the file is closed even if dump() fails.
    """
    with open(filename, 'wb') as fw:
        pickle.dump(inputTree, fw)
# Load the tree back
def grabTree(filename):
    """Deserialize and return a decision tree previously saved by storeTree.

    Fix: pickle.load requires a binary-mode file on Python 3 — open with
    'rb' — and the original never closed the handle; use a context
    manager.
    """
    with open(filename, 'rb') as fr:
        return pickle.load(fr)