# -*- coding: utf-8 -*-
from collections import Counter
from math import log
import operator
import pickle


def calcShannonEnt(dataset):
    """Return the Shannon entropy (in bits) of the class labels in ``dataset``.

    Works on the full dataset as well as on any sub-dataset obtained by
    fixing one feature to a value.

    Args:
        dataset: sequence of rows; the last element of each row is the class
            label, the preceding elements are feature values.

    Returns:
        float: ``-sum(p * log2(p))`` over the class-label distribution;
        ``0.0`` for an empty dataset or a dataset with a single class.
    """
    numSamples = len(dataset)
    labelCounts = Counter(row[-1] for row in dataset)
    entropy = 0.0
    for count in labelCounts.values():
        # float() keeps this a true division on Python 2 as well.
        prob = float(count) / numSamples
        entropy -= prob * log(prob, 2)
    return entropy
def creatDataSet():
    """Return a small hard-coded training set and the names of its features.

    Takes no input. Fresh lists are built on every call, so callers may
    modify the result freely.

    Returns:
        tuple: ``(dataset, labels)`` where ``dataset`` is a list of rows
        ``[feature0, feature1, class_label]`` and ``labels`` names the two
        feature columns in column order.
    """
    samples = (
        (1, 1, 'yes'),
        (1, 1, 'yes'),
        (1, 0, 'no'),
        (0, 1, 'no'),
        (0, 0, 'no'),
    )
    featureNames = ['no surfacing', 'flippers']
    return [list(sample) for sample in samples], featureNames
def getSubDataset(dataset, colIndex, value):
    """Return the rows whose column ``colIndex`` equals ``value``, minus that column.

    Example: ``getSubDataset(dataset, 0, 1)`` keeps the rows where feature 0
    is 1 and strips feature 0 from each of them.

    Args:
        dataset: list of rows (lists); the last element is the class label.
        colIndex: index of the feature column to split on.
        value: the feature value to select.

    Returns:
        list: new row lists (the input rows are not modified); empty when no
        row matches.
    """
    return [row[:colIndex] + row[colIndex + 1:]
            for row in dataset
            if row[colIndex] == value]
def BestFeatToGetSubdataset(dataset):
    """Return the column index of the feature with the largest information gain.

    Information gain of feature ``i`` is
    ``H(D) - sum_v |D_v| / |D| * H(D_v)``, where ``D_v`` is the sub-dataset
    with feature ``i`` fixed to ``v`` and ``H`` is ``calcShannonEnt``.
    Conceptually this decides which feature is tested first in the tree.

    Args:
        dataset: non-empty list of rows; the last column is the class label,
            all other columns are features.

    Returns:
        int: index of the best feature. On ties the lowest index wins.
        ``-1`` only when the rows contain no feature columns at all.

    Raises:
        IndexError: if ``dataset`` is empty.
    """
    # Every column except the trailing class-label column is a feature.
    numFeature = len(dataset[0]) - 1
    baseEntropy = calcShannonEnt(dataset)
    numSamples = float(len(dataset))
    # Start below any achievable gain so that some feature is always chosen,
    # even when every gain is 0 (e.g. XOR-like data, where no single feature
    # helps on its own but splitting is still needed) or slightly negative
    # through float rounding. Starting at 0.0 returned -1 in those cases, and
    # a caller using -1 as a column index would select the class column.
    bestInfoGain = float('-inf')
    bestFeature = -1
    for i in range(numFeature):
        uniqueValues = set(example[i] for example in dataset)
        condEntropy = 0.0
        for value in uniqueValues:
            subDataset = getSubDataset(dataset, i, value)
            # Weight of this branch: fraction of samples with feature i == value.
            prob = len(subDataset) / numSamples
            condEntropy += prob * calcShannonEnt(subDataset)
        infoGain = baseEntropy - condEntropy
        if infoGain > bestInfoGain:
            bestInfoGain = infoGain
            bestFeature = i
    return bestFeature
def mostClass(ClassList):
    """Return the most frequent class label in ``ClassList`` (majority vote).

    Args:
        ClassList: iterable of class labels, e.g. the label column of a
            sub-dataset. Labels must be hashable.

    Returns:
        The label with the highest count. On a tie, the label that appears
        first in ``ClassList`` wins, so the result is deterministic.

    Raises:
        ValueError: if ``ClassList`` is empty.
    """
    # Imported here so this function does not rely on module-level imports.
    from collections import Counter

    # Materialize once so iterators/generators can be both counted and scanned.
    labels = list(ClassList)
    if not labels:
        raise ValueError('mostClass() needs at least one class label')
    counts = Counter(labels)
    # max() scans in list order and keeps the first maximal item, which gives
    # first-occurrence tie-breaking independent of dict ordering.
    return max(labels, key=counts.__getitem__)
def creatTree(dataset, labels):
    """Recursively build an ID3 decision tree.

    Args:
        dataset: non-empty list of rows (lists); the last element of each row
            is the class label, the others are feature values.
        labels: names of the feature columns of ``dataset`` in column order.
            This sequence is NOT modified.

    Returns:
        Either a class label (a leaf) or a nested dict
        ``{feature_label: {feature_value: subtree, ...}}``. Branches exist
        only for feature values that occur in ``dataset``.

    Raises:
        IndexError: if ``dataset`` is empty.
    """
    classList = [example[-1] for example in dataset]
    # Every sample has the same class: this node is a leaf.
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    # Only the class column is left (all features used): majority-vote leaf.
    if len(dataset[0]) == 1:
        return mostClass(classList)
    bestFeat = BestFeatToGetSubdataset(dataset)
    # A negative index means no feature was selected; using it as a column
    # index would split on the class label itself, so make a leaf instead.
    if bestFeat < 0:
        return mostClass(classList)
    bestFeatLabel = labels[bestFeat]
    # Names of the columns that remain once bestFeat is removed. Built as a
    # new list so the caller's ``labels`` is never modified.
    subLabels = list(labels[:bestFeat]) + list(labels[bestFeat + 1:])
    branches = {}
    for value in set(example[bestFeat] for example in dataset):
        # Rows with bestFeat == value, with that column stripped.
        subDataset = getSubDataset(dataset, bestFeat, value)
        branches[value] = creatTree(subDataset, subLabels)
    return {bestFeatLabel: branches}
def classify(inputTree, featlabels, testFeatValue):
    """Predict the class label of one sample by walking a trained decision tree.

    Args:
        inputTree: a tree as built by ``creatTree``: either a class label
            (leaf) or a one-key dict
            ``{feature_label: {feature_value: subtree, ...}}``.
        featlabels: feature names in the same column order as
            ``testFeatValue``; used to map each node's feature label to a
            column index.
        testFeatValue: the sample's feature values (without a class label).

    Returns:
        The class label stored at the leaf that the sample reaches.

    Raises:
        ValueError: if, at some node, the tree has no branch for the sample's
            value of that node's feature.
    """
    node = inputTree
    while isinstance(node, dict):
        # Each internal node has exactly one key: the feature it tests.
        featLabel = next(iter(node))
        branches = node[featLabel]
        featValue = testFeatValue[featlabels.index(featLabel)]
        if featValue not in branches:
            raise ValueError('no branch for feature %r with value %r'
                             % (featLabel, featValue))
        node = branches[featValue]
    return node

def storeTree(trainTree, filename):
    """Serialize a trained decision tree to ``filename`` with pickle.

    Args:
        trainTree: the tree to store (e.g. the result of ``creatTree``).
        filename: path of the file to create or overwrite.

    Returns:
        None.
    """
    # pickle produces bytes, so the file must be opened in binary mode;
    # text mode raises TypeError on Python 3. ``with`` guarantees the file
    # is closed even if pickling fails.
    with open(filename, 'wb') as fw:
        pickle.dump(trainTree, fw)


def grabTree(filename):
    """Load a decision tree previously written by ``storeTree``.

    Args:
        filename: path of the pickle file to read.

    Returns:
        The unpickled tree object.

    Security: ``pickle.load`` can execute arbitrary code while unpickling;
    only load files from a trusted source.
    """
    # Binary mode to match storeTree; ``with`` closes the handle after loading.
    with open(filename, 'rb') as fr:
        return pickle.load(fr)

if __name__ == '__main__':
    # Demo: train on the built-in toy dataset and classify one sample.
    dataset, labels = creatDataSet()
    # Keep a separate copy of the feature names for classify(): creatTree may
    # remove entries from the list it is given.
    storelabels = labels[:]
    trainTree = creatTree(dataset, labels)
    classlabel = classify(trainTree, storelabels, [0, 1])
    # print(x) with a single argument behaves the same on Python 2 and 3.
    print(classlabel)