决策树算法python实现_决策树之python实现ID3算法(例子)

1 #-*- coding: utf-8 -*-

2 from math importlog3 importoperator4 importpickle5 '''

6 输入:原始数据集、子数据集(最后一列为类别标签,其他为特征列)7 功能:计算原始数据集、子数据集(某一特征取值下对应的数据集)的香农熵8 输出:float型数值(数据集的熵值)9 '''

10 defcalcShannonEnt(dataset):11 numSamples =len(dataset)12 labelCounts ={}13 for allFeatureVector indataset:14 currentLabel = allFeatureVector[-1]15 if currentLabel not inlabelCounts.keys():16 labelCounts[currentLabel] =017 labelCounts[currentLabel] += 1

18 entropy = 0.0

19 for key inlabelCounts:20 property = float(labelCounts[key])/numSamples21 entropy -= property * log(property,2)22 returnentropy23

24 '''

25 输入:无26 功能:封装原始数据集27 输出:数据集、特征标签28 '''

29 defcreatDataSet():30 dataset = [[1,1,'yes'],[1,1,'yes'],[1,0,'no'],[0,1,'no'],[0,0,'no']]31 labels = ['no surfacing','flippers']32 returndataset,labels33

34 '''

35 输入:数据集、数据集中的某一特征所在列的索引、该特征某一可能取值(例如,(原始数据集、0,1 ))36 功能:取出在该特征取值下的子数据集(子集不包含该特征)37 输出:子数据集38 '''

39 defgetSubDataset(dataset,colIndex,value):40 subDataset = [] #用于存储子数据集

41 for rowVector indataset:42 if rowVector[colIndex] ==value:43 #下边两句实现抽取除第colIndex列特征的其他特征取值

44 subRowVector =rowVector[:colIndex]45 subRowVector.extend(rowVector[colIndex+1:])46 #将抽取的特征行添加到特征子数据集中

47 subDataset.append(subRowVector)48 returnsubDataset49

50 '''

51 输入:数据集52 功能:选择最优的特征,以便得到最优的子数据集(可简单的理解为特征在决策树中的先后顺序)53 输出:最优特征在数据集中的列索引54 '''

55 defBestFeatToGetSubdataset(dataset):56 #下边这句实现:除去最后一列类别标签列剩余的列数即为特征个数

57 numFeature = len(dataset[0]) - 1

58 baseEntropy =calcShannonEnt(dataset)59 bestInfoGain = 0.0; bestFeature = -1

60 for i in range(numFeature):#i表示该函数传入的数据集中每个特征

61 #下边这句实现抽取特征i在数据集中的所有取值

62 feat_i_values = [example[i] for example indataset]63 uniqueValues =set(feat_i_values)64 feat_i_entropy = 0.0

65 for value inuniqueValues:66 subDataset =getSubDataset(dataset,i,value)67 #下边这句计算pi

68 prob_i = len(subDataset)/float(len(dataset))69 feat_i_entropy += prob_i *calcShannonEnt(subDataset)70 infoGain_i = baseEntropy -feat_i_entropy71 if (infoGain_i >bestInfoGain):72 bestInfoGain =infoGain_i73 bestFeature =i74 returnbestFeature75

76 '''

77 输入:子数据集的类别标签列78 功能:找出该数据集个数最多的类别79 输出:子数据集中个数最多的类别标签80 '''

81 defmostClass(ClassList):82 classCount ={}83 for class_i inClassList:84 if class_i not inclassCount.keys():85 classCount[class_i] =086 classCount[class_i] += 1

87 sortedClassCount =sorted(classCount.iteritems(),88 key=operator.itemgetter(1),reverse =True)89 returnsortedClassCount[0][0]90

91 '''

92 输入:数据集,特征标签93 功能:创建决策树(直观的理解就是利用上述函数创建一个树形结构)94 输出:决策树(用嵌套的字典表示)95 '''

96 defcreatTree(dataset,labels):97 classList = [example[-1] for example indataset]98 #判断传入的dataset中是否只有一种类别,是,返回该类别

99 if classList.count(classList[0]) ==len(classList):100 returnclassList[0]101 #判断是否遍历完所有的特征,是,返回个数最多的类别

102 if len(dataset[0]) == 1:103 returnmostClass(classList)104 #找出最好的特征划分数据集

105 bestFeat =BestFeatToGetSubdataset(dataset)106 #找出最好特征对应的标签

107 bestFeatLabel =labels[bestFeat]108 #搭建树结构

109 myTree ={bestFeatLabel:{}}110 del(labels[bestFeat])111 #抽取最好特征的可能取值集合

112 bestFeatValues = [example[bestFeat] for example indataset]113 uniqueBestFeatValues =set(bestFeatValues)114 for value inuniqueBestFeatValues:115 #取出在该最好特征的value取值下的子数据集和子标签列表

116 subDataset =getSubDataset(dataset,bestFeat,value)117 subLabels =labels[:]118 #递归创建子树

119 myTree[bestFeatLabel][value] =creatTree(subDataset,subLabels)120 returnmyTree121

122 '''

123 输入:测试特征数据124 功能:调用训练决策树对测试数据打上类别标签125 输出:测试特征数据所属类别126 '''

127 defclassify(inputTree,featlabels,testFeatValue):128 firstStr =inputTree.keys()[0]129 secondDict =inputTree[firstStr]130 featIndex =featlabels.index(firstStr)131 for firstStr_value insecondDict.keys():132 if testFeatValue[featIndex] ==firstStr_value:133 if type(secondDict[firstStr_value]).__name__ == 'dict':134 classLabel =classify(secondDict[firstStr_value],featlabels,testFeatValue)135 else: classLabel =secondDict[firstStr_value]136 returnclassLabel137

138

139 '''

140 输入:训练树,存储的文件名141 功能:训练树的存储142 输出:143 '''

144 defstoreTree(trainTree,filename):145

146 fw = open(filename,'w')147 pickle.dump(trainTree,fw)148 fw.close()149 defgrabTree(filename):150

151 fr =open(filename)152 returnpickle.load(fr)153

154

155 if __name__ == '__main__':156 dataset,labels =creatDataSet()157 storelabels = labels[:]#复制label

158 trainTree =creatTree(dataset,labels)159 classlabel = classify(trainTree,storelabels,[0,1])160 print classlabel

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值