最近重温了本科学过的一个算法,决策树。想写篇博文整理一下,一则与大家分享,希望能帮到有需要的人。二来作为一个学习笔记,记录一下学习内容。为了让大家对这个算法有一个系统清晰的认识,这篇博文侧重于对算法整体流程以及算法核心部分进行阐述。决策树算法整体流程大致都一样,不同的是在进行特征选择时,特征选择策略不同,这篇博文以ID3算法进行特征选择来学习决策树。如果你已掌握决策树算法整个工作流程,想要详细学习其他特征选择策略以及有关决策树剪枝的内容,推荐学习资料 《统计学习方法》(李航著)
话不多说,进入正题!
算法介绍
直观理解
决策树,是有监督学习算法,主要用于分类及回归任务。决策树算法本质是从训练数据集中学习归纳出一组规则,要求该规则能对训练集进行正确分类。举个例子:表中数据,有眼睛颜色、头发颜色两个特征,黄种人、白种人、混血三个分类。
构建的决策树如下:
相较于表格数据,决策树是不是一目了然,(如果你不这样认为,换一种说法,相较于表格,决策树是不是显得数据分类更清晰一些?从表格中你可以一眼看出分类的规律吗,很不方便),决策树的作用可不是单单让人看着方便。从数据集中训练得到的最终决策树,当你输入测试实例, (眼镜颜色、头发颜色),从根节点开始进行决策,最终会到达一个叶子节点,得到输入数据对应的分类结果。也就是说,决策树可以进行预测。
决策树最上面的节点称为根节点,是整个决策树的开始,眼睛颜色是什么呢?问题答案决定数据下一步走向的分支,这是一个决策过程。每一个分支是一个新的决策节点或者是树的叶子节点。每个决策节点代表一个问题,也就是待分类对象的特征。比如”头发颜色“这一节点,头发颜色是什么颜色呢?这是一个新的决策。每个叶子节点代表一种分类结果。
沿着决策树从上到下遍历的过程中,每个节点都会进行测试。决策节点上,回答结果(不同的测试输出),导致不同的分支,最后会到达叶子结点。这个过程就是决策树进行决策的过程,利用若干个特征来判断实例的类别。
主要步骤
一、特征选择
二、决策树生成
三、决策树剪枝 (涉及到决策树优化问题,本期内容不做详细介绍)
特征选择,如上方的例子,我们在选取特征时,先使用哪一个呢?首先要明确,我们的目标是要通过归纳出的规则,使数据在较少的步骤内找到对应的分类。特征选择的准则不同,对应的特征选择的顺序可能会发生变化。
这里,以ID3算法为例来进行学习。ID3算法以信息增益为度量准则进行特征选择。信息增益是什么?在此我们需要清晰熵的概念,先别急,我们往下看。
“组织杂乱无章的数据的一种方法是使用信息论度量信息”。提到信息论,就不得不想到信息论之父——克劳德 香农(被公认为20世纪最聪明的人之一)。” 在香农写完信息论之后,冯诺依曼建议用 ” 熵 “ 这个术语,因为大家都不知道它是什么意思”, 看到这样,真想直呼:大牛们命名伟大成果的方式可真有趣,amazing!
熵,定义为信息的期望值。
如果一个样本X可能的分类为 {1,2,…, n},(第1类,第2类……第n类) ,对应的概率为p(Xi)(i=1,2,…,n),则样本X的的熵定义为
大家可以理解为,一个衡量数据混乱程度的一个物理量。一组数据,越混乱无序,则熵值越大,越规则有序,则熵值越小。理解了熵的概念后,信息增益就很容易明白了。当我们选择特征A对数据集进行划分,划分前后数据集熵值的变化就是信息增益。即
信息增益 = H(划分前)— H(后)
ID3算法,就是遍历所有特征,计算以当前特征划分数据集对应的信息增益,最后,选择信息增益较大的特征来进行本轮划分。每完成一次划分,需要在待选特征中删除此特征。
以海洋生物数据为例:
表中有两个特征,“不浮出水面是否可以生存”,“是否有脚蹼”,我们先选择哪个特征进行划分呢?这个问题我们暂时不急于解决,后面结合程序就会一目了然。
程序实现
计算数据集的熵
import operator
from numpy import *
from math import log
####构建数据集#######################
def createDataSet():
data=[[1,1,'yes'],
[1,1,'yes'],
[1,0,'no'],
[0,1,'no'],
[0,1,'no']]
feature_label=['no surfacing','flippers']
return data,feature_label
#####计算数据集的香农熵#################
def calshang(dataSet):
numEntry=len(dataSet) #统计所有实例数
label_list=[example[-1] for example in dataSet] #遍历每一行元素,取每一行最后一个元素构成标签列表。['yes', 'yes', 'no', 'no', 'no']
classCount={} #创建字典,对数据集中分类情况进行统计。
for label in label_list:
if label not in classCount.keys():
classCount[label]=0
classCount[label]+=1
# print(classCount) #{'yes': 2, 'no': 3}
shang=0.0
for key in classCount: #遍历类别字典,计算熵值
prob=classCount[key]/numEntry #float(classCount[key])/numEntry
shang+=-prob*log(prob,2)
return shang
dataSet,feature_label=createDataSet()
print('原始数据集为 %s'%(dataSet))
print('特征名称为:%s'%(feature_label))
print("熵为:")
print(calshang(dataSet))
程序输出:
原始数据集为 [[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]
特征名称为:['no surfacing', 'flippers']
熵为:
0.9709505944546686
划分数据集函数
###########以某一特征X,划分数据集######################
'''
dataSet:划分前的数据集
featureX_axis:特征X的索引号
value:特征X的取值
reduced_dataSet:以特征X进行划分,返回特征X取值为value所有样本构成的数据集
'''
def split_dataSet(dataSet,featureX_axis,value):
reduced_dataSet=[] #划分后的数据集
for data in dataSet:
if data[featureX_axis]==value: #特征X值为value的,划分到一个子集中
reduced_data=data[0:featureX_axis]
reduced_data.extend(data[featureX_axis+1:]) #去除特征X对应的列
reduced_dataSet.append(reduced_data)
return reduced_dataSet
##测试函数功能
dataSet,feature_label=createDataSet()
split_dataSet(dataSet,1,1) #以索引为1的特征(是否有脚蹼)进行划分,返回特征对应为1的数据样本。注意:返回的数据中只剩下一个特征(索引为0的特征,即"不浮出水面是否可以生存"),因为以索引为1的特征进行划分,接下来的划分中不能再使用这个特征,所以数据会减少一个特征。
print(split_dataSet(dataSet,1,1))
程序输出
[[1, 'yes'], [1, 'yes'], [0, 'no'], [0, 'no']]
以索引为1的特征(是否有脚蹼)进行划分,返回对应特征值为1的数据样本。注意:返回的数据中只剩下一个特征(索引为0的特征,即"不浮出水面是否可以生存"),因为以索引为1的特征进行划分,接下来的划分中不能再使用这个特征,所以数据会减少一个特征。
特征选取(选择最好的数据集划分方式)
######################选择最好的数据集划分方式############
def chooseBestSplit(dataSet):
num_dataSet=len(dataSet) #计算划分前的数据集的实例个数
num_feature=len(dataSet[0])-1 #计算划分前的数据集的特征个数
base_shang=calshang(dataSet) #计算划分前的数据集的熵
bestInfoGain=0.0 #初始化信息增益为0
bestFeature=-1 #记录最佳划分的特征的索引
for i in range(num_feature): #遍历所有特征,寻求最佳特征划分
feature=[example[i] for example in dataSet] #得到第i个特征,对应的所有实例的数据
uniqueValue=set(feature) #第i个特征所有可能的取值
newEntropy=0.0
for value in uniqueValue: #遍历第i个特征所有可能的取值,计算按照第i个特征划分后数据集的信息熵
subDataSet=split_dataSet(dataSet,i,value)
prob=len(subDataSet)/float(num_dataSet)
newEntropy+=prob*calshang(subDataSet)
infoGain=base_shang-newEntropy #计算按照第i个特征划分后数据集的信息增益
print("以索引为"+str(i)+"的特征进行划分,信息增益为:"+str(infoGain))
if infoGain>bestInfoGain:
bestInfoGain=infoGain
bestFeature=i #记录最佳划分特征索引号
#print("bestFeature is %d" %(bestFeature))
return bestFeature
###函数功能测试
myDat,label=createDataSet()
bestFeature=chooseBestSplit(myDat)
print("bestFeature index is %d" %(bestFeature))
程序输出
以索引为0的特征进行划分,信息增益为:0.4199730940219749
以索引为1的特征进行划分,信息增益为:0.17095059445466854
bestFeature index is 0
到这里,我们先前的疑问就得到了解答,“不浮出水面是否可以生存”,“是否有脚蹼”,我们先选择哪个特征进行划分呢?先选择特征”不浮出水面是否可以生存“时,对应的信息增益要大于选择特征”是否有脚蹼"“。这样的解释似乎还不够直观,我们作图理解
按照第一个特征进行划分,特征值为0的分组中有两个鱼类,一个非鱼类。特征值为1的分组里全是非鱼类。按照第二个特征进行划分,第一个分组中有两个属于鱼类,两个属于非鱼类,另一个分组则只有一个非鱼类。第一种方式很好的处理了数据。就如同计算结果一样,选择第一个特征(索引为0)“不浮出水面是否可以生存”,划分前后熵值变化大,数据趋于有序的程度更大,故而选择第一个特征来进行划分。
构建决策树
################递归构建决策树##############3
#################最后决策###########################
#数据集已经处理了所有属性,但是类标签依然不是唯一,则少数服从多数
def majorityVote(classList):
classCount={}
for vote in classList:
if vote not in classCount.keys():
classCount[vote]=0
classCount[vote]+=1
sortedClassCount=sorted(classCount.iteritems(),key=operator.itemgetter(1),reverse=True) #operator模块提供的itemgetter函数用于获取对象的哪些维的数据,参数为一些序号(即需要获取的数据在对象中的序号),,reverse为Ture的时候降序排列
return sortedClassCount[0][0]
######################构建树##################
def createTree(dataSet,feature_label):
'''
dataSet:待创建树的数据集
feature_label:数据集中的特征标签
myTree:返回值,创建的树
'''
classList=[example[-1] for example in dataSet]
if len(set(classList))==1:return classList[0] #数据集中全是一种类别时,停止继续划分
if len(dataSet[1])==1:return majorityVote(classList) #数据集中所有特征都使用完毕时,返回出现次数最多的类别
bestFeature=chooseBestSplit(dataSet) #不满足以上条件时,需要对数据集继续划分,即在图中节点有分支。选取最好的特征
bestFeature_name=feature_label[bestFeature] #该特征对应的特征名称
del(feature_label[bestFeature]) #该特征在本次使用,在特征列表中删除
myTree={bestFeature_name:{}} #字典构建树
bestFeature_value=[example[bestFeature] for example in dataSet]
bestFeature_unique_value=set(bestFeature_value)
for value in bestFeature_unique_value:
sublabel=feature_label[:]
myTree[bestFeature_name][value]=createTree(split_dataSet(dataSet,bestFeature,value),sublabel)
return myTree
##函数功能测试
myDat,label=createDataSet()
print(createTree(myDat,label))
程序输出
{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
输出结果就是构建完成的决策树,树结构以字典形式存储。接下来我们对其进行可视化。
绘制树形图
新建文件plot_tree.py,此段程序部分注释参照博客 https://blog.csdn.net/gaoyueace/article/details/78742579,在此特别感谢博主分享。
import matplotlib.pyplot as plt
decisionNode=dict(boxstyle="sawtooth",fc="0.8")#boxstyle为矩形框的类型,sawtooth表示锯齿形,fc:边框线粗细
leafNode=dict(boxstyle="round4",fc="0.8")
arrow_args=dict(arrowstyle="<-") #箭头方向
def plotNode(nodeTxt,centerPt,parentPt,nodeType): #nodeTxt为要显示的文本,centerPt为文本的中心点,parentPt为箭头指向文本的点,xy是箭头尖的坐标
createPlot.ax1.annotate(nodeTxt,xy=parentPt,xycoords='axes fraction',xytext=centerPt,textcoords="axes fraction",va="center",ha="center",bbox=nodeType,arrowprops=arrow_args)
##xytest设置注释内容显示的中心位置,xycoords和textcoords是坐标xy与xytext的说明(按轴坐标),若textcoords=None,则默认textcoords与xycoords相同,若都未设置,默认为data
#va/ha设置节点框中文字的位置,va为纵向取值为(u'top', u'bottom', u'center', u'baseline'),ha为横向取值为(u'center', u'right', u'left')
###########获取叶节点的数目#################################
def getNumLeafs(myTree):
numLeafs=0
# firstStr=myTree.keys()[0]
firstStr=list(myTree.keys())[0] #当前树的keys,如:classcount={'gmy':{'English':98,'math':100,'biology':100,'computer':{'data_base':100,'data_structure':90}}} ,list(myTree.keys())[0]为gmy
#当classcount1={'gmy':98,'amy':100,'qz':92},list(myTree.keys())有三个元素,分别为,gmy,amy,qz
secondDict=myTree[firstStr] #获取子树(下一层字典)
for key in secondDict.keys():
if type(secondDict[key]).__name__=='dict':
numLeafs+=getNumLeafs(secondDict[key])
else: numLeafs+=1
return numLeafs
###########获取树的层数#################################
def getTreeDepth(myTree):
maxDepth=0
#firstStr=myTree.keys()[0] python3.6不支持字典index循环了
firstSides = list(myTree.keys())
firstStr = firstSides[0]#找到输入的第一个元素
secondDict=myTree[firstStr]
for key in secondDict.keys():
if type(secondDict[key]).__name__=='dict':
thisDepth=1+getTreeDepth(secondDict[key])
else:thisDepth=1
if thisDepth>maxDepth:maxDepth=thisDepth
return maxDepth
####################检索树#############################
'''
方便起见,函数中存储两棵已经构建好的决策树,用来直接测试可视化函数
'''
def retrieveTree(i):
listOfTrees=[{'no surfacing':{0:'no',1:{'flippers':{0:'no',1:'yes'}}}},
{'no surfacing':{0:'no',1:{'flippers':{0:{'head':{0:'no',1:'yes'}},1:'no'}}}}
]
return listOfTrees[i]
#################plotTree函数##########################
def plotMidText(cntrPt,parentPt,txtString):
xMid=(parentPt[0]-cntrPt[0])/2.0+cntrPt[0]
yMid=(parentPt[1]-cntrPt[1])/2.0+cntrPt[1]
createPlot.ax1.text(xMid,yMid,txtString)
def plotTree(myTree,parentPt,nodeText):
numleafs=getNumLeafs(myTree) #获取叶节点数量
depth=getTreeDepth(myTree) #获取树深度
firstStr=list(myTree.keys())[0]
cntrPt=(plotTree.xOff+(1.0+float(numleafs))/2.0/plotTree.totalW,plotTree.yOff)
plotMidText(cntrPt,parentPt,nodeText)
plotNode(firstStr,cntrPt,parentPt,decisionNode)
secondDict=myTree[firstStr]
plotTree.yOff=plotTree.yOff-1.0/plotTree.totalD
for key in secondDict.keys():
if type(secondDict[key]).__name__=='dict':
plotTree(secondDict[key],cntrPt,str(key))
else:
plotTree.xOff=plotTree.xOff+1.0/plotTree.totalW
plotNode(secondDict[key],(plotTree.xOff,plotTree.yOff),cntrPt,leafNode)
plotMidText((plotTree.xOff,plotTree.yOff),cntrPt,str(key))
plotTree.yOff=plotTree.yOff+1.0/plotTree.totalD
def createPlot(inTree):
fig=plt.figure(1,facecolor='white')
fig.clf()
axprops=dict(xticks=[],yticks=[])
createPlot.ax1=plt.subplot(111,frameon=False,**axprops)
plotTree.totalW=float(getNumLeafs(inTree))
plotTree.totalD=float(getTreeDepth(inTree))
plotTree.xOff=-0.5/plotTree.totalW;plotTree.yOff=1.0
plotTree(inTree,(0.5,1.0),'')
plt.show()
###函数功能测试
myTree=retrieveTree(0)
createPlot(myTree)
myTree['no surfacing'][3]='maybe'
createPlot(myTree)
程序运行结果:
输入任意一棵决策树,上述绘图程序可根据树的高度和宽度来设置布局,使决策树有较为规整的呈现。例子中,给no surfacing添加一个属性maybe时,程序运行结果如下,多一个节点,整棵树布局依旧会保持合适在画布中央。
算法测试:使用决策树执行分类
#################使用决策树进行分类#####################
def classify(decisionTree,featLabel,testVec):
'''
decisionTree:构建好的决策树
featLabel:数据集特征名称列表 ['surfacing','flippers']
testVec:输入的测试向量,例如,想知道[1,1]的类别是什么
classify_label:返回的分类结果
'''
firstStr=list(decisionTree.keys())[0]
# print("first is %s" %(firstStr))
# print(featLabel)
index_feature=featLabel.index(firstStr) #得到当前特征在测试实例中的下标索引,方便将实例中该特征对应的值与决策树进行对比
secondDict=decisionTree[firstStr]
for key in secondDict.keys():
if key==testVec[index_feature]:#如果实例中特征对应的值与key值相同,则判断secondDict[key]是否为字典,若不是,则secondDict[key]即为分类结果。否则,继续向树下一层递进
if type(secondDict[key]).__name__=='dict':
classify_label=classify(secondDict[key],featLabel,testVec)
else:classify_label=secondDict[key]
return classify_label
myDat,label=createDataSet()
train_label=label[:]
myTree=createTree(myDat,train_label)
print("分类结果为:%s"%(classify(myTree,label,[1,0])))
程序输出:
分类结果为:no
测试实例 [1,0 ] 对应的分类结果为no
存储决策树
构建决策树,执行的是递归,这是很耗时间的。尤其当数据集很大时,将会耗费更多的时间。最好能在每次执行分类时就调用已经构建好的决策树。因此,为了解决这个问题,需要使用python模块pickle序列化对象,任何对象都可以执行序列化操作,序列化对象可以在磁盘上保存,并在需要的时候读出来。
################存储决策树######################
def storeTree(inputTree,filename):
import pickle ###这个模块可以序列化对象,序列化对象可以在磁盘上保存对象,并在需要的时候读取出来
fw=open(filename,'wb')
pickle.dump(inputTree,fw) #fw表示inputTree要写入的文件对象,fw必须以二进制可写模式打开,即“wb”
fw.close()
def grabTree(filename):
import pickle
fr=open(filename,'rb')
return pickle.load(fr) #fr必须以二进制可读模式打开,即“rb”,
myDat,label=createDataSet()
myTree=createTree(myDat,label)
storeTree(myTree,'classifierStorage.txt')
print(grabTree('classifierStorage.txt'))
程序执行完毕后,可以看到工作区内多了一个名为:classifierStorage的文件,里面存储的就是构建好的决策树。
读取文件,程序运行结果,
{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
应用
”隐形眼镜数据集是非常著名的数据集,它包含了很多患者眼部状况的观察条件以及医生推荐的隐形眼镜类型。. 隐形眼镜类型包括硬材质 (hard)、软材质 (soft)以及不适合佩戴隐形眼镜 (no lenses)“,数据来源于UCI数据库。特征有四个:age(年龄)、prescript(症状)、astigmatic(是否散光)、tearRate(眼泪情况)
##########隐形眼镜分类#########################
fr=open('lenses.txt')
lenses_data=[example.strip().split('\t') for example in fr.readlines()]
lenses_label=['age','prescript','astigmatic','rearRate']
train_lense_label=lenses_label[:] ##构建树时,lenses_label会发生变化,因此在此处备份lenses—_label便于下面进行分类,传入train_lense_label。
myTree=createTree(lenses_data,train_lense_label)
print(myTree)
plot_tree.createPlot(myTree)
result=classify(myTree,lenses_label,['young','hyper','no','normal']) #测试,进行分类
print("分类结果为:%s"%(result))
程序运行结果
构建的决策树为:
{'rearRate': {'reduced': 'no lenses', 'normal': {'astigmatic': {'no': {'age': {'pre': 'soft', 'presbyopic': {'prescript': {'hyper': 'soft', 'myope': 'no lenses'}}, 'young': 'soft'}}, 'yes': {'prescript': {'hyper': {'age': {'pre': 'no lenses', 'presbyopic': 'no lenses', 'young': 'hard'}}, 'myope': 'hard'}}}}}}
字典看起来不直观,我们调用plot_tree中的createPlot()函数对其进行可视化
决策树绘制:
测试样例 [‘young’,‘hyper’,‘no’,‘normal’],分类结果:
分类结果为:soft
沿着决策树不同的分支,可以得到不同患者需要佩戴的隐形眼镜类型。从图中可以发现,我们不需要遍历太多的数据,进行四次决策,也就是说问四个问题就可以确定患者需要佩戴哪种类型的隐形眼镜。
算法存在的问题
1、 过度拟合,隐形眼镜的例子中可以看到,决策树很好的匹配了实验数据,然而这些匹配选项可能太多了。 一个好的模型不仅要能够很好地拟合训练数据,而且对未知样本也要能够准确地分类。过度拟合会导致模型的泛化能力下降。
2、决策树算法比较适合处理离散数值的属性。实际应用中属性是连续的或者离散的情况都比较常见。
算法优点
计算复杂度不高,输出的结果易于理解,对缺失值不敏感,可以处理不相关的特征数据。
以上就是决策树算法的主要内容,建议小伙伴们有时间的话动手实践一下。
独学无朋,则孤陋而难成,如果有不理解的地方,欢迎与我交流,我们可以共同学习!学海无涯,个人整理,内容难免会有纰漏,欢迎道友指正,感激不尽!
本文参考学习资料:《机器学习实战》(peter Harrington 著)
《统计学习方法》 《李航 著》