决策树

一、前沿

决策树是一种非常常用的机器学习算法,可以应用于分类和回归中,其中比较著名的有三种:ID3、C4.5和Cart算法。对于前两种只能针对分类,即离散数据集,且可以是多叉分类树;最后一种CART算法是分类决策树,既可以用于分类树,也可以用于回归树。决策树由节点和有向边组成。节点又分为内部节点和叶子节点。内部节点表示一个特征或属性;叶子节点表示一个分类。

接下来通过一一介绍这三个算法来分析决策树算法。

二、重要概念

2.1 熵(Entropy)

表示随机变量的不确定性的度量。熵越大,则其随机变量的不确定性也越大。

设:X是取有限值得离散随机变量,其概率分布是:


则随机变量X的熵定义为:

          

2.2  条件熵

条件熵H(Y|X)表示在已知随机变量X的条件下,随机变量Y的不确定性;即随机变量X给定的条件下,变量Y的条件熵。


这个表示在某种特征的条件下,所得到的条件熵。

2.3  经验熵和经验条件熵

当熵和条件熵中的概率由数据估计(特别是极大似然估计)得到时,所对应的熵和条件熵分别经验熵和经验条件熵。

2.4  信息增益

当训练集合D的经验熵H(D)与特征A给定条件下的经验条件熵H(D|A)之差:

g(D,A) = H(D) - H(D|A)

2.5  信息增益率

信息增益g(D,A)与训练数据集D关于特征A的值的熵之比。即


其中,

而ID3采用信息增益来作为划分依据,C4.5采用信息增益率来作为划分的依据。

2. 6  基尼指数

也是用来衡量数据不确定性的一个指标,在分类问题中,假设有K个类,样本点属于第K类的概率为Pk,则基尼指数定义为:


CART算法在分类问题中,基尼指数作为特征选择的依据,选择基尼指数最小的特征及切分点作为最优特征和最优切分点。

在回归问题中,特征选择及最佳划分特征值的依据是:划分后的样本的均方差之和最小。

三、算法分析

清楚决策树停止划分的情况:

1、当前数据集的属性值为空。

2、当前所有样本的类别相同。

3、信息增益小于一定的值。(预剪枝中使用)


算法实现主要包括两大模块:最优划分特征选择和决策树的生成。

最优划分特征选择部分

1、计算信息熵: 
输入:数据集(含标签) 
输出:信息熵 



2、划分数据集 
输入:数据集、属性(特征)、指定特征值value 
输出:具有该指定特征值的所有数据集 
重难点:数据切分,选出指定的数据!首先找到所有属性的值是value 的样本,然后去除该属性。组成返回数据集!!!在决策树生成部分的步骤e中很重要。

3、选择最佳特征 
输入:数据集(含类别标签)、所有属性值 
输出:最佳特征 
a、计算数据集总的信息熵 Entropy 
b、分别遍历每一个属性(特征)(一个属性可能有多个特征值),计算每个属性下的信息增益(包含多个特征值下的信息熵)Em 
c、寻找最大的(Entropy - Em),此时即是最大的信息增益,选择的特征既是最佳分类特征。

4、投票表决 
输入:数据集 
输出:类别标签 
统计该数据集下每个类别出现的次数,排序,返回出现次数最多的类别。

决策树的生成(该函数是一个递归的过程)CreateTree

输入:数据集、特征 
输出:字典型数据——决策树 
a、判断是否满足停止划分的条件 
若当前数据集的属性值为空,则投票表决当前样本中最多的类别 
若当前所有的样本类别相同,则返回当前数据的类别。

b、寻找当前数据的最佳划分特征 
c、将最佳特征作为关键字,保存到字典中 
d、从当前的属性集合中删除该最佳特征 
e、遍历该最佳划分特征的所有属性值feat,循环调用函数 CreateTree(输入参数为:最佳特征值为feat的所有数据集,去除最佳特征的属性集合)

四、代码注意: 
1、生成的决策树用字典保存,并且每个关键字的值是一个字典; 
2、生成的决策树可以用 pickle 序列化对象保存; 
3、ID3 算法适用于标称型数据,在函数的输入、输出中,数据类型为 list

五、代码

[python]  view plain  copy
 print ?
  1. #-*- coding:utf-8 -*-  
  2. import numpy as np  
  3. from numpy import *  
  4. import pandas as pd  
  5. from math import *  
  6. import operator  
  7. import pickle              # 使用该模块实现对决策树的保存  
  8.   
  9. # 数据导入  
  10. def loadData(fileName):  
  11.     dataSet = []  
  12.     fr = open(fileName)  
  13.     for featVector in fr.readlines():  
  14.         lineVector = featVector.strip().split('\t')  
  15.         dataSet.append(lineVector)  
  16.     return dataSet  
  17.   
  18. def calcuEntropy(myData):     # 计算信息熵  
  19.     numSample = len(myData)  
  20.     myClassCount = {}  
  21.     for featVector in myData:  
  22.         theKey = featVector[-1]  
  23.         if theKey not in myClassCount:  
  24.             myClassCount[theKey] = 0  
  25.         myClassCount[theKey] += 1  
  26.     myEntropy = 0  
  27.     for Keys in myClassCount.keys():  
  28.         Px = float(myClassCount[Keys])/numSample  
  29.         myEntropy -= Px*log(Px,2)            # 需要导入 math 库  
  30.     return myEntropy  
  31.   
  32. # 划分数据集:返回划分好的数据集  
  33. def splitDataSet(dataX,FeatureNumber,value):  # 输入:数据集、第i 个特征、该属性的值  
  34.     retMat = []  
  35.     for featVect in dataX:  
  36.         if featVect[FeatureNumber] == value:  
  37.             x1 = featVect[:FeatureNumber]  
  38.             x2 = featVect[FeatureNumber+1:]  
  39.             x1.extend(x2)  
  40.             retMat.append(x1)  
  41.     #print"retMat", retMat  
  42.     return retMat  
  43.   
  44.  # 计算最优特征:计算每个特征下的信息熵,信息熵最大的既是最优特征,返回的数字 i 代表第 i 个特征  
  45. def GetBestFeature(dataM):  
  46.     BestFeat = -1; LargestInformGain = -1      # 最佳特征、最大信息增益  
  47.     theEntropy = calcuEntropy(dataM)  
  48.     FeatNumber = len(dataM[0])-1  
  49.     for i in range(FeatNumber):  
  50.         FeatList = [example[i] for example in dataM]  # 统计每个特征有几个特征值  
  51.         FeatUnique = set(FeatList)              # 每个特征中的特征值,计算每个特征值下的信息增益  
  52.         NewEntropy = 0.0  
  53.         for j in FeatUnique:  
  54.             retMat = splitDataSet(dataM,i,j)     # 得到满足条件的数据  
  55.             Prob =  len(retMat)/float(len(dataM))  
  56.             NewEntropy -= Prob*calcuEntropy(retMat)         # 注意:这里是子数据集的概率x 该数据集的熵  
  57.         informGain = theEntropy + NewEntropy  
  58.         if informGain > LargestInformGain:  
  59.             LargestInformGain = informGain  
  60.             BestFeat = i  
  61.     return BestFeat  
  62.   
  63. def RoleOfVote(dataM):           # 投票规则:少数服从多数  
  64.     lables = [example[-1for example in dataM]  
  65.     lablesCount = {}  
  66.     for i in lables:  
  67.         if i not in lablesCount.keys():  
  68.             lablesCount[i] = 0  
  69.         lablesCount[i] += 1  
  70.     theSort = sorted(lablesCount.iteritems(),key =operator.itemgetter(1),reverse=True)  
  71.     return theSort[0][0]                # 返回出现次数最多的类别标签  
  72.   
  73.   
  74. # 决策树生成: 首先,判断是否满足停止划分的条件:1、所有的类别标签相同  2、属性值为空  
  75. def CreateTree(dataSet,label):  
  76.     allLabels = [example[-1for example in dataSet]  
  77.     #print "调用几次"  
  78.     if allLabels.count(allLabels[0])==len(allLabels):  
  79.         return allLabels[0]  
  80.     if(len(dataSet[0])==1):                # 没有属性可以划分时,采用投票规则  
  81.         return RoleOfVote(dataSet)  
  82.     featNumber = GetBestFeature(dataSet)      # 返回最佳划分的编号  
  83.     #print featNumber  
  84.     bestFeature = label[featNumber]  
  85.     myTree = {bestFeature:{}}  
  86.     del(label[featNumber])  
  87.     featValue =  [example[featNumber] for example in dataSet]  
  88.     uniqueFeatValues = set(featValue)  
  89.     for i in uniqueFeatValues:  
  90.         subLabels = label[:]  
  91.         myTree[bestFeature][i] = CreateTree(splitDataSet(dataSet, featNumber, i), subLabels)  
  92.     return myTree  
  93.   
  94. # 对输入样本进行分类  
  95. def SampleClassify(inputTree, featLabels, testVec):            # 输入树、属性标签、测试向量  
  96.     firstStr = inputTree.keys()[0]                           # 决策树的第一个关键词(第一个划分的属性)  
  97.     secondDict = inputTree[firstStr]                          # 决策树每个关键字对应的值也是字典  
  98.     featIndex = featLabels.index(firstStr)                  # 得到该特征在属性标签中的编号 K  
  99.     key = testVec[featIndex]                                # 得到当前测试向量中第 K 号特征值  
  100.     valueOfFeat = secondDict[key]  
  101.     if isinstance(valueOfFeat, dict):  
  102.         classLabel = SampleClassify(valueOfFeat, featLabels, testVec)  
  103.     else:  
  104.         classLabel = valueOfFeat  
  105.     return classLabel  
  106.   
  107. # 将生成的树保存  
  108. def storeTree(inputTree, filename):  
  109.     fw = open(filename, 'w')  
  110.     pickle.dump(inputTree, fw)  
  111.     fw.close()  
  112.   
  113. # 加载保存的树  
  114. def loadTree(filename):  
  115.     fr = open(filename)  
  116.     return pickle.load(fr)  
  117.   
  118. if __name__=="__main__":  
  119.     print "hello world"  
  120.     dataSet = loadData('lenses.txt')  
  121.     labels = ['age','prescript','astigmatic','tearRate']  
  122.     slabels = labels[:]              # pyhton 函数中的参数是按照引用方式传递的,为防止labels 改变,复制类标签带入  
  123.     myTree = CreateTree(dataSet,slabels)  
  124.     print labels  
  125.     #storeTree(myTree,"xu")              # 保存决策树  
  126.     test = dataSet[0]  
  127.     #print test[:-1]  
  128.     print"预测结果是:", SampleClassify(myTree,labels,test[:-1])  
  129.     print"真是标签是:",test[-1]  

六、CART算法

1、特征选择(依据:总方差最小)

输入:数据集、op = [m,n] 
输出:最佳特征、最佳划分特征值

m表示剪枝前总方差与剪枝后总方差差值的最小值; n: 数据集划分为左右两个子数据集后,子数据集中的样本的最少数量;

1、判断数据集中所有的样本标签是否相同,是:返回当前标签; 
2、遍历所有的样本特征,遍历每一个特征的特征值。计算出每一个特征值下的数据总方差,找出使总方差最小的特征、特征值 
3、比较划分前和划分后的总方差大小;若划分后总方差减少较小,则返回的最佳特征为空,返回的最佳划分特征值会为当前数据集标签的平均值。 
4、比较划分后的左右分支数据集样本中的数量,若某一分支数据集中样本少于指定数量op[1],则返回的最佳特征为空, 
返回的最佳划分特征值会为当前数据集标签的平均值。 
5、否则,返回使总方差最小的特征、特征值

二、回归树的生成函数 createTree 
输入:数据集 
输出:生成回归树 
1、得到当前数据集的最佳划分特征、最佳划分特征值 
2、若返回的最佳特征为空,则返回最佳划分特征值(作为叶子节点) 
3、声明一个字典,用于保存当前的最佳划分特征、最佳划分特征值 
4、执行二元切分;根据最佳划分特征、最佳划分特征值,将当前的数据划分为两部分 
5、在左子树中调用createTree 函数, 在右子树调用createTree 函数。 
6、返回树。

注:在生成的回归树模型中,划分特征、特征值、左节点、右节点均有相应的关键词对应。

三、(后)剪枝:(CART 树一定是二叉树,所以,如果发生剪枝,肯定是将两个叶子节点合并)

输入:树、测试集 
输出:树

1、判断测试集是否为空,是:对树进行塌陷处理 
2、判断树的左右分支是否为树结构,是:根据树当前的特征值、划分值将测试集分为Lset、Rset两个集合; 
3、判断树的左分支是否是树结构:是:在该子集递归调用剪枝过程; 
4、判断树的右分支是否是树结构:是:在该子集递归调用剪枝过程; 
5、判断当前树结构的两个节点是否为叶子节点: 
是: 
a、根据当前树结构,测试集划分为Lset,Rset两部分; 
b、计算没有合并时的总方差NoMergeError,即:测试集在Lset 和 Rset 的总方差之和; 
c、合并后,取叶子节点值为原左右叶子结点的均值。求取测试集在该节点处的总方差MergeError,; 
d、比较合并前后总方差的大小;若NoMergeError > MergeError,返回合并后的节点;否则,返回原来的树结构; 
否: 
返回树结构。

代码如下:

[python]  view plain  copy
 print ?
  1. #-*- coding:utf-8 -*-  
  2. from numpy import *  
  3. import numpy as np  
  4. # 三大步骤:  
  5. ''''' 
  6. 1、特征的选择:标准:总方差最小 
  7. 2、回归树的生成:停止划分的标准 
  8. 3、剪枝: 
  9. '''  
  10.   
  11. # 导入数据集  
  12. def loadData(filaName):  
  13.     dataSet = []  
  14.     fr = open(filaName)  
  15.     for line in fr.readlines():  
  16.         curLine = line.strip().split('\t')  
  17.         theLine = map(float, curLine)                 # map all elements to float()  
  18.         dataSet.append(theLine)  
  19.     return dataSet  
  20.   
  21. # 特征选择:输入:       输出:最佳特征、最佳划分值  
  22. ''''' 
  23. 1、选择标准 
  24. 遍历所有的特征Fi:遍历每个特征的所有特征值Zi;找到Zi,划分后总的方差最小 
  25. 停止划分的条件: 
  26. 1、当前数据集中的标签相同,返回当前的标签 
  27. 2、划分前后的总方差差距很小,数据不划分,返回的属性为空,返回的最佳划分值为当前所有标签的均值。 
  28. 3、划分后的左右两个数据集的样本数量较小,返回的属性为空,返回的最佳划分值为当前所有标签的均值。 
  29. 当划分的数据集满足上述条件之一,返回的最佳划分值作为叶子节点; 
  30. 当划分后的数据集不满足上述要求时,找到最佳划分的属性,及最佳划分特征值 
  31. '''  
  32.   
  33. # 计算总的方差  
  34. def GetAllVar(dataSet):  
  35.     return var(dataSet[:,-1])*shape(dataSet)[0]  
  36.   
  37. # 根据给定的特征、特征值划分数据集  
  38. def dataSplit(dataSet,feature,featNumber):  
  39.     dataL =  dataSet[nonzero(dataSet[:,feature] > featNumber)[0]]  
  40.     dataR = dataSet[nonzero(dataSet[:,feature] <= featNumber)[0]]  
  41.     return dataL,dataR  
  42.   
  43. # 特征划分  
  44. def choseBestFeature(dataSet,op = [1,4]):          # 三个停止条件可否当作是三个预剪枝操作  
  45.     if len(set(dataSet[:,-1].T.tolist()[0]))==1:     # 停止条件 1  
  46.         regLeaf = mean(dataSet[:,-1])           
  47.         return None,regLeaf                   # 返回标签的均值作为叶子节点  
  48.     Serror = GetAllVar(dataSet)  
  49.     BestFeature = -1; BestNumber = 0; lowError = inf  
  50.     m,n = shape(dataSet) # m 个样本, n -1 个特征  
  51.     for i in range(n-1):    # 遍历每一个特征值  
  52.         for j in set(dataSet[:,i].T.tolist()[0]):  
  53.             dataL,dataR = dataSplit(dataSet,i,j)  
  54.             if shape(dataR)[0]<op[1or shape(dataL)[0]<op[1]: continue  # 如果所给的划分后的数据集中样本数目甚少,则直接跳出  
  55.             tempError = GetAllVar(dataL) + GetAllVar(dataR)  
  56.             if tempError < lowError:  
  57.                 lowError = tempError; BestFeature = i; BestNumber = j  
  58.     if Serror - lowError < op[0]:               # 停止条件 2   如果所给的数据划分前后的差别不大,则停止划分  
  59.         return None,mean(dataSet[:,-1])           
  60.     dataL, dataR = dataSplit(dataSet, BestFeature, BestNumber)  
  61.     if shape(dataR)[0] < op[1or shape(dataL)[0] < op[1]:        # 停止条件 3  
  62.         return None, mean(dataSet[:, -1])  
  63.     return BestFeature,BestNumber  
  64.   
  65.   
  66. # 决策树生成  
  67. def createTree(dataSet,op=[1,4]):  
  68.     bestFeat,bestNumber = choseBestFeature(dataSet,op)  
  69.     if bestFeat==Nonereturn bestNumber  
  70.     regTree = {}  
  71.     regTree['spInd'] = bestFeat  
  72.     regTree['spVal'] = bestNumber  
  73.     dataL,dataR = dataSplit(dataSet,bestFeat,bestNumber)  
  74.     regTree['left'] = createTree(dataL,op)  
  75.     regTree['right'] = createTree(dataR,op)  
  76.     return  regTree  
  77.   
  78. # 后剪枝操作  
  79. # 用于判断所给的节点是否是叶子节点  
  80. def isTree(Tree):  
  81.     return (type(Tree).__name__=='dict' )  
  82.   
  83. # 计算两个叶子节点的均值  
  84. def getMean(Tree):  
  85.     if isTree(Tree['left']): Tree['left'] = getMean(Tree['left'])  
  86.     if isTree(Tree['right']):Tree['right'] = getMean(Tree['right'])  
  87.     return (Tree['left']+ Tree['right'])/2.0  
  88.   
  89. # 后剪枝  
  90. def pruneTree(Tree,testData):  
  91.     if shape(testData)[0]==0return getMean(Tree)  
  92.     if isTree(Tree['left'])or isTree(Tree['right']):  
  93.         dataL,dataR = dataSplit(testData,Tree['spInd'],Tree['spVal'])  
  94.     if isTree(Tree['left']):  
  95.         Tree['left'] = pruneTree(Tree['left'],dataL)  
  96.     if isTree(Tree['right']):  
  97.         Tree['right'] = pruneTree(Tree['right'],dataR)  
  98.     if not isTree(Tree['left']) and not isTree(Tree['right']):  
  99.         dataL,dataR = dataSplit(testData,Tree['spInd'],Tree['spVal'])  
  100.         errorNoMerge = sum(power(dataL[:,-1] - Tree['left'],2)) + sum(power(dataR[:,-1] - Tree['right'],2))  
  101.         leafMean = getMean(Tree)  
  102.         errorMerge = sum(power(testData[:,-1]-  leafMean,2))  
  103.         if errorNoMerge > errorMerge:  
  104.             print"the leaf merge"  
  105.             return leafMean  
  106.         else:  
  107.             return Tree  
  108.     else:  
  109.         return Tree  
  110.   
  111. # 预测  
  112. def forecastSample(Tree,testData):  
  113.     if not isTree(Tree): return float(tree)  
  114.     # print"选择的特征是:" ,Tree['spInd']  
  115.     # print"测试数据的特征值是:" ,testData[Tree['spInd']]  
  116.     if testData[0,Tree['spInd']]>Tree['spVal']:  
  117.         if isTree(Tree['left']):  
  118.             return forecastSample(Tree['left'],testData)  
  119.         else:  
  120.             return float(Tree['left'])  
  121.     else:  
  122.         if isTree(Tree['right']):  
  123.             return forecastSample(Tree['right'],testData)  
  124.         else:  
  125.             return float(Tree['right'])  
  126.   
  127. def TreeForecast(Tree,testData):  
  128.     m = shape(testData)[0]  
  129.     y_hat = mat(zeros((m,1)))  
  130.     for i in range(m):  
  131.         y_hat[i,0] = forecastSample(Tree,testData[i])  
  132.     return y_hat  
  133.   
  134. if __name__=="__main__":  
  135.     print "hello world"  
  136.     dataMat = loadData("ex2.txt")  
  137.     dataMat = mat(dataMat)  
  138.     op = [1,6]    # 参数1:剪枝前总方差与剪枝后总方差差值的最小值;参数2:将数据集划分为两个子数据集后,子数据集中的样本的最少数量;          
  139.     theCreateTree =  createTree(dataMat,op)  
  140.    # 测试数据  
  141.     dataMat2 = loadData("ex2test.txt")  
  142.     dataMat2 = mat(dataMat2)  
  143.     #thePruneTree =  pruneTree(theCreateTree, dataMat2)  
  144.     #print"剪枝后的后树:\n",thePruneTree  
  145.     y = dataMat2[:, -1]  
  146.     y_hat = TreeForecast(theCreateTree,dataMat2)  
  147.     print corrcoef(y_hat,y,rowvar=0)[0,1]              # 用预测值与真实值计算相关系数  
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值