机器学习实战+第三章_决策树

 本章采用的是ID3算法。通过计算香农熵来确定最佳特征(bestFeature),再通过最佳特征将树划分成子树,递归的调用createTree函数。

优缺点:

可视化非常好

无法处理数值型数据

可能出现过度匹配的现象,可以通过剪枝了缓解

注明:

书里面的代码时用的python2,我是3.6版本,所以有些地方会有出入。

其次,关于treePlotter模块,xOff和偏移量的公式我进行了修正,使它更易于理解。


from math import log
import operator

'''
#决策树_ID3算法
选择香农熵最小的属性划分数据集(如果该属性下labels一致或再无属性可以划分则返回(最多的)label作为该分支决策的结果)
将该属性(bestfeature)作为字典的键,字典的值是一个(子)字典。
这个(子)字典的键由该属性的值的集合组成,(子)字典的值为递归的使用createTree函数处理该属性的值所有对应的(删除了bestfeature的)dataset返回的结果。
'''

def createDataSet():
	dataset=[[1,1,'yes'],[1,1,'yes'],[1,0,'no'],[0,1,'no'],[0,1,'no']]
	attrilist=['no surfacing','flippers']		#'no surfacing'为dataset第一列属性的label,'flippers'为dataset第二列属性的label
	return dataset,attrilist



#计算香农熵(Shannon Entropy),熵越高则混合的数据越多
def calShannonEnt(dataset):
	row=len(dataset)
	countdict={}
	for vector in dataset:
		label=vector[-1]
		countdict[label]=countdict.get(label,0)+1
	shannonent=0
	for key in countdict:
		prob=countdict[key]/row
		shannonent-=prob*log(prob,2)
	return shannonent



#按照给定特征划分数据集
def splitDataSet(dataset,axis,value):	#axis 用于划分数据集的特征,为dataset的列下标		#value 属性值
	retDataSet=[]
	for vector in dataset:
		if vector[axis]==value:
			reducedvector=vector[:axis]
			reducedvector.extend(vector[axis+1:])
			retDataSet.append(reducedvector)
	return retDataSet
'''
#在函数中传递的是列表的引用,所以在函数内部对dataset的修改会影响到该列表对象的整个生存期。为了避免这一影响,我们重新构造了一个retDataSet.

#extend()与append():
a=[1,2,3]
b=[4,5,6]
a.append(b)
print(a)	#[1,2,3,[4,5,6]]
a=[1,2,3]
a.extend(b)	
print(a)	#[1,2,3,4,5,6]
'''	



#选择最好的属性来划分:对dataset的每个属性都计算香农熵并比较信息增益。选择信息增益最大的一项,返回其属性的index。
#Feature特征   Attribut属性 两者的含义是相同的
def chooseBestFeatureToSplit(dataset):
	numFeature=len(dataset[0])-1		#dataset的属性数,为dataset的列数减一,因为labels项占据了dataset的一列
	baseEnt=calShannonEnt(dataset)	#原始的香农熵
	bestInfoGain=0					#最大信息增益
	bestFeature=-1
	for i in range (numFeature):
		featureList=[row[i] for row in dataset]	#等价于fetureList=dataset[:,i],但此时的dataset需为numpy的array列表,而非list列表
		#这个列表推导式每次将dataset一行的第i项元素加入列表,整个dataset遍历完,featureList就接收到了由dataset第i列组成的列表
		featureSetList=set(featureList)		#set(list)返回一个列表,列表里的元素是对list求集合的结果,即去除重复值。
		newEnt=0
		for value in featureSetList:
			subdataset=splitDataSet(dataset,i,value)
			prob=len(subdataset)/len(dataset)
			newEnt+=prob*calShannonEnt(subdataset)	#熵的期望=子集的熵*子集占整个数据集的比重
		infoGain=baseEnt-newEnt				#熵越小越好,所以infoGain(信息增益)越大越好
		if infoGain>bestInfoGain:
			bestInfoGain=infoGain
			bestFeature=i
	return bestFeature



#输入为标签列表,输出为出现次数最多的标签
def majorityCnt(labellist):		#labellist是dataset或dataset子集的最后一列,标签列
	labelcount={}
	for label in labellist:
		labelcount[label]=labelcount.get(label,0)+1
	sortedlabelcount=sorted(labelcount.items(),key=operator.itemgetter(1),reverse=True)
	return sortedlabelcount[0][0]



#构造决策树
#注意:原代码中第二个参数不是attrilist而是labels,但labels易于fish的'yes'和'no'标签混淆,故换成attrilist表示'no surfacing','flipper'这样的列属性
def createTree(dataset,attrilist):
	labellist=[row[-1] for row in dataset]				#labelslist 标签列表
	if labellist.count(labellist[0])==len(labellist):	#list.count(x)	计数x在列表list中出现的次数。	
		return labellist[0]				#递归结束条件一:这里如果'=='成立,表示该数据集中的所有label(fish)完全相同,无需继续划分数据集,直接返回该标签。
	if len(dataset[0])==1:
		return majorityCnt(labellist)	#递归结束条件二:使用完了所有特征,仍不能将数据集划分成一致的label,返回出现次数最多的标签代表这个分组
	bestFeature=chooseBestFeatureToSplit(dataset)		#找到香农熵最小的那一个属性的index
	bestFeatureLabel=attrilist[bestFeature]		
	myTree={bestFeatureLabel:{}}		#myTree是一个字典,字典的键是属性,字典的值是另一个字典
	subattrilist=attrilist[:]		#列表作为参数传递的是列表的引用,故构造subattribute防止attributelist改变
	del(subattrilist[bestFeature])			#del(list[index]):列表list的删除函数,index就是attrilist要删掉的内容。
	featureValues=[row[bestFeature] for row in dataset]		#featureValues为bestfeature那项属性的所有属性值的列表
	featureValuesSet=set(featureValues)	#featureValuesSet为bestfeature那项属性可能出现的属性值的集合
	for value in featureValuesSet:
		myTree[bestFeatureLabel][value]=createTree(splitDataSet(dataset,bestFeature,value),subattrilist)	
		#createTree可能返回label作为值,也可能继续返回嵌套的字典作为值
		#{属性:{属性值种类1:可能返回一个标签,属性值种类二:可能返回一个嵌套字典}}
	return myTree	

#createTree和plotTree只是构造了决策树并将其可视化,但对输入的向量进行分类返回其label就要交给classify
#比如已经求得决策树,那么对[1,1,label?],其vector=[1,1],那么它的label是多少就需要借助classify函数。
def classify(tree,attrilist,vector):
	firstattri=list(tree.keys())[0]
	index=attrilist.index(firstattri)
	key=vector[index]
	value=tree[firstattri][key]
	if type(value).__name__=='dict':
		return classify(value,attrilist,vector)
	else:
		return value
'''
isinstance(object,type)函数:
>>> isinstance(1, int)
True
>>> isinstance(1.0, float)
True
>>>isinstance(1,(int,float))
True
注意:isinstance与type区别:type()的话类型必须一致,而在isinstance()中object可以与type一致,也可以是type的子类
'''			


#由于计算决策树的开销很大,所以使用pickle模块序列化对象,在需要的时候可以再读取出来。任何对象都可以序列化,字典也不例外。
def storeTree(inputTree,filename):
	import pickle
	fw = open(filename,'wb')
	pickle.dump(inputTree,fw)
	fw.close() 
def grabTree(filename):
	import pickle
	fr = open(filename,'rb')
	return pickle.load(fr)
'''
pickle模块:
如字典,列表以及自定义的结构的对象,程序关闭后消失,再使用就要重新构造。
pickle可以将这些对象存储在文件中,需要时直接读取,并且可以被识别还原成对象,而不用重新构造。
首先要存储,打开文件,用dump函数写入。
然后读取,打开文件,用load函数加载到程序里。
注意,读和写都是二进制。
'''



'''
dataset,attrilist=createDataSet()
print(dataset)
print(calShannonEnt(dataset))		
bestFeature=chooseBestFeatureToSplit(dataset)
print(bestFeature)
mytree=createTree(dataset,attrilist)
print(mytree)
print(classify(mytree,attrilist,[1,0]))
storeTree(mytree,'fish')
treeoffish=grabTree('fish')
print(treeoffish)
'''
		import matplotlib.pyplot as plt
import numpy as np
import trees

decisionNode = dict(boxstyle="sawtooth", fc="0.8")		#boxstyle 注释的边框类型,fc facecolor则为边框填充的颜色
leafNode = dict(boxstyle="round4", fc="0.8")
arrow_args = dict(arrowstyle="<-")						#arrowstyle 箭头类型

#添加注释的函数,将annotate这个函数进行了包装
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    createPlot.ax1.annotate(nodeTxt, xy=parentPt,  xycoords='axes fraction',xytext=centerPt, textcoords='axes fraction',
            va="center", ha="center", bbox=nodeType, arrowprops=arrow_args )
#annotate(s,xy,xytext) 添加注释,s为添加的字符,xy为多要标注点的位置,xytext为注释的位置,arrowprops为箭头类型,不写则只有注释没有箭头。
#xycoords为箭头尾的位置信息,一般不用设置。textcoords为箭头和注释的位置信息,一般也不用设置。	   
#va和ha是标签的对齐方式,默认是标签的左下对齐xytext。若xytext=(x,y),则va为center表明标签中心的纵坐标为y,ha为center则标签的中心的横坐标的x。
#bbox为边框,boxstyle为边框的形状,fc facecolor为边框填充色



#求myTree的宽度(叶子节点数),叶子节点数就是mytree中出现的标签的数目
def NumLeafs(myTree):
    numLeafs = 0
    firstStr = list(myTree.keys())[0]		#获得myTree的键的列表的第一个键
    secondDict = myTree[firstStr]	#获得第一个键的值,是一个嵌套的字典或一个字典
    for key in secondDict.keys():	#遍历这个有字典组成的值
        if type(secondDict[key]).__name__=='dict':		#逐个检查这个字典的每个值的类型是否为字典。注意,
            numLeafs += NumLeafs(secondDict[key])
        else:   numLeafs +=1
    	#值分为两种情况,一种是label,遇到一个label加一个;一种是字典,那就递归,看里面包含了多少label。等遍历完这个myTree,叶子节点的数目就知道了。    
    return numLeafs
#修正后myTree的宽度为叶子节点数-1
def getNumLeafs(myTree):
	return NumLeafs(myTree)-1

#求myTree的深度
def getTreeDepth(myTree):
    maxDepth = 0 
    firstStr = list(myTree.keys())[0]
    #源代码为firstStr = myTree.keys()[0],但在python3.6中dict的items,keys操作返回的是view,要用index索引必须先将其转换成list
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__=='dict':
            thisDepth = 1 + getTreeDepth(secondDict[key])
        else:   thisDepth = 1
        		#求每个分支的深度,叶子节点深度为1,每有一个字典就说明分支一次,则深度加1。但是根节点的深度没有算入,故3层的树深度为2。
        if thisDepth > maxDepth: maxDepth = thisDepth		#树的深度为最大的分支的深度
    return maxDepth


#在父节点和子节点的箭头的中心位置添加属性值信息
def plotMidText(cntrPt, parentPt, txtString):		#cntrPt 子节点	parentPt 父节点	txtString 属性值	 
    xMid = (parentPt[0]-cntrPt[0])/2.0 + cntrPt[0]
    yMid = (parentPt[1]-cntrPt[1])/2.0 + cntrPt[1]
    createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center", rotation=30)	#rotation为文本旋转的角度	#text()添加纯文本信息

def plotTree(myTree, parentPt, nodeTxt):
    numLeafs = getNumLeafs(myTree)
    depth = getTreeDepth(myTree)
    firstStr = list(myTree.keys())[0]     
    cntrPt = (plotTree.xOff + numLeafs/2.0/plotTree.totalW, plotTree.yOff)	#x轴:基点xOff+偏移量(叶子节点数占totalW的比例,该比例除以2即为偏移量)
    #cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, plotTree.yOff)	#源码
    plotMidText(cntrPt, parentPt, nodeTxt)
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
    plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD
    for key in secondDict.keys():
        if type(secondDict[key]).__name__=='dict':		#键为字典,xOff不变。但递归该字典时,其下的叶子节点一定会使xOff改变
            plotTree(secondDict[key],cntrPt,str(key))       
        else:   
            #plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW	#源码为先更改xOff值
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
            plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW		#源码中xOff为负值,故须先更改xOff的位置。经更改后现在可以反过来。
    plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD
'''
绘图的说明:
图像是(0,0)到(1,1)的正方形,xOff和yOff是每次绘图的基点。
y轴很简单,将图形划分成totalD份,基点yOff从1开始,yOff每次的偏移量设为1/totalD即可。
xOff较难理解。比如整个树有三个叶子节点,那么每个叶子节点的横坐标是多少?1/3,2/3,1?不是的,应该是1/6,1/2,5/6。否则整个图形就偏右了。
但在y轴就不存在偏上的问题,为什么?三层的树totalD=2,yOff从1开始,每次下降1/totalD=0.5,分别是1,0.5,0,图形没有往上跑。
源码中将xOff从-1/6开始,基于此先确定了偏移量,再确定子节点的计算,最后确定了xOff的值为-1/6。这不是很好理解
因此,我直接仿造y轴的方法,也将totalW自减了1,这样,xOff就可以从0开始了。
这样子做的话,公式很易于理解。但缺点是图示较源码的宽一些,美观性略差,但基本影响不大。
源码中numLeafs+1的含义是:它的这个图x轴的边界不是(0,1),xOff是从-1/totalW开始的,即边界为(-1/totalW,1+1/totalW),所以要加1。
'''


def createPlot(inTree):
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    axprops = dict(xticks=[], yticks=[])
    createPlot.ax1 = fig.add_subplot(111, frameon=False, **axprops)    #**axprops就没有了刻度,frameon=False就没有了边框
    plotTree.totalW = getNumLeafs(inTree)
    plotTree.totalD = float(getTreeDepth(inTree))
    plotTree.xOff = 0; plotTree.yOff = 1.0;
    plotTree(inTree, (0.5,1.0), '')		
    #在第一次调用时,'no surfacing'属性位置与父节点(0.5,1)位置相同,故不产生箭头,且参数nodeText为空,这使得'no surfacing'的父节点形同虚设
    plt.show()
'''
def createPlot():
    fig = plt.figure(1, facecolor='white')
    fig.clf()		#clf() 清除fig之前的图像
    createPlot.ax1 = fig.add_subplot(111, frameon=False)
    plotNode('decisionNode',(0.5,0.1),(0.1,0.5),decisionNode)
    plotNode('leafNode',(0.8,0.1),(0.3,0.8),leafNode)
    plt.show()
'''

#createPlot(trees.mytree)      import trees
import treeplotter

path='D:\ADA\save\python\MachineLearninginAction\machinelearninginaction\Ch03\lenses.txt'
fr=open(path,'r')
dataset=[line.strip().split('\t') for line in fr.readlines()]	
#.strip(rm)        删除s字符串中开头、结尾处,出现在rm删除序列的字符
attributelist=['age','prescript','astigmatic','tearRate']
lensesTree=trees.createTree(dataset, attributelist)
treeplotter.createPlot(lensesTree)


  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值