from math import log
import operator
from matplotlib.font_manager import FontProperties
import matplotlib.pyplot as plt
def calcshannonent(dataSet):
#返回数据集行数
numEntries = len(dataSet)
#保存每个标签(label)出现次数的字典
labelCounts = {}
#对每组特征向量进行统计
for featVec in dataSet:
currentLabel = featVec[-1] #提取标签信息
if currentLabel not in labelCounts.keys(): #如果标签没有放入统计次数的字典,添加进去
labelCounts[currentLabel]=0
labelCounts[currentLabel] += 1 #label计数
shannonEnt=0.0 #经验熵
#计算经验熵
for key in labelCounts:
prob = float(labelCounts[key]) / numEntries #选择该标签的概率
shannonEnt -= prob*log(prob,2) #利用公式计算
return shannonEnt
def splitdataset(dataset, axis, value):
"""
把数据集按照特征值进行拆分
axis:一个数,表示第几个特征
value: 一个指定的值
"""
retDataSet = []
#遍历每一条数据
for featVec in dataset:
if featVec[axis] == value:
reducedFeatVec = featVec[:axis]
#extend添加的是列表内的元素而不是列表本身
reducedFeatVec.extend(featVec[axis+1:])
retDataSet.append(reducedFeatVec)
return retDataSet
def choosebestfeaturetosplit(dataset):
"""
循环选择最好的用于切分数据的特征
"""
numfeatures = len(dataset[0]) - 1 #特征数量为第0个列表的长度-1
baseentropy = calcshannonent(dataset) #基本准则为香农熵
bestinfogain = 0.0
bestfeature = -1 #首先最好的信息增益为0,最好的特征为-1
#得到第i个特征的value集合
for i in range(numfeatures): #遍历特征数
featlist = [example[i] for example in dataset] #把该特征的所有值取出来
uniquevals = set(featlist) #分类
newentropy = 0.0
#
for value in uniquevals:
subdataset = splitdataset(dataset, i, value) #按照特征值进行切分
prob = len(subdataset) / float(len(dataset)) #计算每个值的比例
newentropy += prob * calcshannonent(subdataset) #计算条件经验熵
infogain = baseentropy - newentropy #计算信息增益
print('第%d个特征的信息增益为%.2f' % (i,infogain))
if (infogain > bestinfogain): #循环选择最大的信息增益
bestinfogain = infogain
bestfeature = i
return bestfeature
def majoritycnt(classList):
"""
统计classList中出现次数最多的元素(类标签)
Parameters:
classList:类标签列表
Returns:
sortedClassCount[0][0]:出现次数最多的元素(类标签)
"""
classCount={}
#统计classList中每个元素出现的次数
for vote in classList:
if vote not in classCount.keys():
classCount[vote] = 0
classCount[vote] += 1
#根据字典的值降序排列
sortedClassCount=sorted(classCount.items(),key=operator.itemgetter(1),reverse=True)
return sortedClassCount[0][0]
def createtree(dataSet,labels):
"""
函数说明:创建决策树
Parameters:
dataSet:训练数据集
labels:分类属性标签
featLabels:存储选择的最优特征标签
Returns:
myTree:决策树
"""
classList=[example[-1] for example in dataSet]
#两个停止条件
if classList.count(classList[0]) == len(classList): #完全相同时,只有一个值,所以用classlist[0]
return classList[0]
if len(dataSet[0]) == 1: #dataset每一条数据是一个列表,第0条数据的长度为1指的是只剩下一个特征可用
return majoritycnt(classList) #虽然只有一个特征,但不一定只有一个分类。 选择数目最多的分类
#拿到最优特征
bestFeat = choosebestfeaturetosplit(dataSet) #每一轮选择的最好的特征是信息增益最大的特征
bestFeatLabel = labels[bestFeat] #从labels列表里取第i个(labels指的是每个特征的名字)
#featLabels.append(bestFeatLabel) #
myTree = {bestFeatLabel:{}}
del(labels[bestFeat]) #从labels列表里删除
print(labels)
featValues = [example[bestFeat] for example in dataSet]
uniqueVls = set(featValues)
#遍历所有属性值,在每个值上创建子树
for value in uniqueVls:
sublabels = labels[:]
myTree[bestFeatLabel][value] = createtree(splitdataset(dataSet,bestFeat,value),sublabels)
return myTree
def getnumleafs(mytree):
"""
获取叶子的数量
"""
numleafs = 0
firststr = next(iter(mytree))
seconddict = mytree[firststr]
for key in seconddict.keys():
if type(seconddict[key]).__name__ == 'dict':
numleafs += getnumleafs(seconddict[key])
else:
numleafs += 1
return numleafs
def gettreedepth(mytree):
"""
获取树的深度
"""
maxdepth = 0
firststr = next(iter(mytree))
seconddict = mytree[firststr]
for key in seconddict.keys():
if type(seconddict[key]).__name__ == 'dict':
thisdepth = 1 + gettreedepth(seconddict[key])
else:
thisdepth = 1
if thisdepth > maxdepth: maxdepth = thisdepth
return maxdepth
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
"""
画图的基本参数
nodetxt: 方框内的文字
centerpt: 文本中心点
parentpt: 指向文本的点(箭头的尾部)
nodetype: 节点的样式
"""
#箭头格式
arrow_args = dict(arrowstyle="<-")
#设置中文字体
font = FontProperties(fname=r"c:\windows\fonts\simsun.ttc", size=14)
createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',
xytext=centerPt, textcoords='axes fraction',
va="center", ha="center", bbox=nodeType, arrowprops=arrow_args, FontProperties=font)
def plotMidText(cntrPt, parentPt, txtString):
"""
计算标注位置(标注:特征的某个值)
cntpt : 文本的中心点
parentpt: 指向文本的点
txtstring:标注的内容
"""
#当前节点的
#xMid = (parentPt[0]-cntrPt[0])/2.0 + cntrPt[0] #计算标注位置
#yMid = (parentPt[1]-cntrPt[1])/2.0 + cntrPt[1]
xMid = (parentPt[0]+cntrPt[0])/2.0
yMid = (parentPt[1]+cntrPt[1])/2.0
createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center", rotation=30)
def plotTree(myTree, parentPt, nodeTxt):
"""
绘制树图
mytree:传进去的决策树(字典型)(每一轮的决策树是不同的,第二轮往后都是子树)
parentpt:
nodetxt:节点文本
"""
#设置判断结点格式
decisionNode = dict(boxstyle="sawtooth", fc="0.8")
#设置叶子节点格式
leafNode = dict(boxstyle="round4", fc="0.8")
#获取当前树叶子数目
numLeafs = getnumleafs(myTree)
#print(numLeafs)
#判断节点的文本(某个特征)
firstStr = next(iter(myTree))
#节点的中心位置
cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, plotTree.yOff)
#添加标注
plotMidText(cntrPt, parentPt, nodeTxt) #因此第一个箭头上不需要文本,传入‘ ’,从第二层箭头开始,传入的都是子树的键的值。
#绘制判断节点
plotNode(firstStr, cntrPt, parentPt, decisionNode) #这里是画了一个指向该节点的箭头(最开始指向tear rate这个节点)
#下一个字典
secondDict = myTree[firstStr]
#下一个y的
plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD #总的树深为totalD,yoff又为1,因此每次下降1/4是正确的
#print('1:%s' % plotTree.yOff)
for key in secondDict.keys():
if type(secondDict[key]).__name__=='dict':
#是判断节点的,递归画树,
#下一个树的parentpt = 本次判断节点的centerpt(也就是说,下一个节点箭头尾部在这个节点的中心)
#下一颗树的根节点文本为键的值
plotTree(secondDict[key],cntrPt,str(key))
else:
#是叶节点的,直接确定其坐标,画叶节点
#确定X的偏移量,第一步中相当于第一个叶节点的位置为0.5/9 =0.055,即左侧
#可以通过调整偏移量调整叶节点的位置
plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW
#叶节点的文本是该键的值,文本中心为x,y,箭头尾部为上一轮的节点,类型为叶子型
plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
#标注的位置,标注的文本为键
plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
#该语句是为了保证每一棵子树的各个节点高度一致
#该语句只有在前面进行到全部都是叶节点(即不存在递归的时候,才会执行)
plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD
#print('2:%s' % plotTree.yOff)
def createPlot(inTree):
"""
画图函数(主体函数)
intree: 传入的决策树
"""
#创建fig
fig = plt.figure(1, facecolor='white')
#清空fig
fig.clf()
#去掉x、y轴
axprops = dict(xticks=[], yticks=[])
createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
#获取叶结点数目
plotTree.totalW = float(getnumleafs(inTree)) #9
#获取树高
plotTree.totalD = float(gettreedepth(inTree))
#print(plotTree.totalD) #4
#设定xoff和yoff的初始值
plotTree.xOff = -0.5/plotTree.totalW
#print(plotTree.xOff)
plotTree.yOff = 1.0
#传入绘制树图函数的参数,包括决策树,根节点的坐标和根节点的文本
plotTree(inTree, (0.5,1.0), '')
plt.show()
if __name__ == '__main__':
fr = open('lensess.txt')
lenses = []
for row in fr.readlines():
curline = row.strip().split('\t')
lenses.append(curline)
lenseslabels = ['age','prescript', 'astigmatic', 'tear_rate']
#print(lenses)
mytree = createtree(lenses,lenseslabels)
print(mytree)
#print(mytree)
createPlot(mytree)
关于绘图函数plottree的一点说明:
该函数用的是递归算法,在第一个循环里,进行的是方框1中的操作,即绘制第一个判断节点 tear rate。本来按照循环,这里应该有一个箭头,但是第一个循环中centerpt 和parentpt的坐标都是(0.5,1)。相当于箭头的头和尾在同一个点上,因此不显示。
在以后的每个循环中,绘制的图形都是从上一层节点到本节点的箭头,注释以及本层节点。