1 #coding=utf-8
2 from math importlog3 importoperator4 defcalcShannonEnt(dataSet):5 numEntries =len(dataSet)6 labelCounts ={}7 for featVec indataSet:8 currentLabel = featVec[-1] #提取类标号的属性值
9 #把类标号不同的属性值及其个数存入字典中
10 if currentLabel not inlabelCounts .keys():11 labelCounts [currentLabel ]=012 labelCounts [currentLabel]+=1
13 shannonEnt = 0.0
14 #计算类标号的平均信息量,如公式中H(S)
15 for key inlabelCounts :16 prob = float(labelCounts [key])/numEntries17 shannonEnt -= prob * log(prob,2)18 returnshannonEnt19
20 defcreateDataSet():21 dataSet = [[1, 1, 'yes'],22 [1, 1, 'yes'],23 [1, 0, 'no'],24 [0, 1, 'no'],25 [0, 1, 'no']]26 labels = ['no surfacing','flippers']27 #change to discrete values
28 returndataSet, labels29 defcreateDataSet1():30 dataSet = [[u'小于等于5',u'高',u'否',u'一般',u'否'],31 [u'小于等于5', u'高', u'否', u'好', u'否'],32 [u'5到10', u'高', u'否', u'一般', u'否'],33 [u'大于等于10', u'中', u'否', u'一般', u'是'],34 [u'大于等于10', u'低', u'是', u'一般', u'是'],35 [u'5到10', u'中', u'否', u'好', u'否'],36 [u'5到10', u'高', u'是', u'一般', u'是'],37 [u'小于等于5', u'中', u'否', u'一般', u'否'],38 [u'5到10', u'中', u'否', u'好', u'否'],39 [u'大于等于10', u'高', u'是', u'好', u'是'],40 [u'5到10', u'低', u'是', u'一般', u'是'],41 [u'小于等于5', u'中', u'是', u'一般', u'是'],42 [u'小于等于5', u'低', u'是', u'一般', u'是'],43 [u'大于等于10', u'中', u'是', u'好', u'是']]44 labels = [u'役龄',u'价格',u'是否关键部件',u'磨损程度']45 returndataSet ,labels46
47 #按照给定特征划分数据集,把符合给定属性值的对象组成新的列表
48 defsplitDataSet(dataSet,axis,value):49 retDataSet =[]50 for featVec indataSet:51 #选择符合给定属性值的对象
52 if featVec[axis] ==value:53 reduceFeatVec = featVec[:axis] #对对象的属性值去除给定的特征的属性值
54 reduceFeatVec.extend(featVec[axis+1:])55 retDataSet.append(reduceFeatVec ) #把符合且处理过的对象添加到新的列表中
56 returnretDataSet57
58 #选取最佳特征的信息增益,并返回其列号
59 defchooseBestFeaturesplit(dataSet):60 numFeatures = len(dataSet[0])-1 #获得样本集S 除类标号之外的属性个数,如公式中的k
61 baseEntropy = calcShannonEnt(dataSet) #获得类标号属性的平均信息量,如公式中H(S)
62
63 bestInfoGain = 0.0 #对最佳信息增益的初始化
64 bestFeature = -1 #最佳信息增益的属性在样本集中列号的初始化
65
66 #对除类标号之外的所有样本属性一一计算其平均信息量
67 for i inrange(numFeatures ):68 featList = [example[i] for example in dataSet] #提取第i 个特征的所有属性值
69 uniqueVals = set(featList ) #第i 个特征所有不同属性值的集合,如公式中 aq
70 newEntropy = 0.0 #对第i 个特征的平均信息量的初始化
71 #计算第i 个特征的不同属性值的平均信息量,如公式中H(S| Ai)
72 for value inuniqueVals:73 subDataSet = splitDataSet(dataSet,i,value ) #提取第i 个特征,其属性值为value的对象集合
74 prob = len (subDataSet )/float(len(dataSet)) #计算公式中P(Cpq)的概率
75 newEntropy += prob * calcShannonEnt(subDataSet ) #第i个特征的平均信息量,如 公式中H(S| Ai)
76 infoGain = baseEntropy - newEntropy #第i 个的信息增益量
77 if (infoGain > bestInfoGain ): #选取最佳特征的信息增益,并返回其列号
78 bestInfoGain =infoGain79
80 bestFeature =i81 returnbestFeature82
83 #选择列表中重复次数最多的一项
84 defmajorityCnt(classList):85 classCount={}86 for vote inclassList :87 if vote not inclassCount .keys():88 classCount [vote] =089 classCount[vote] += 1
90 sortedClassCount =sorted(classCount.iteritems() ,91 key=operator.itemgetter(1),92 reverse= True ) #按逆序进行排列,并返回由元组组成元素的列表
93 returnsortedClassCount[0][0]94
95 #创建决策树
96 defcreateTree(dataSet,labels):97 Labels = labels [:] #防止改变最初的特征列表
98 classList = [example[-1] for example in dataSet ] #获得样本集中的类标号所有属性值
99 if classList.count(classList [0]) == len(classList): #类标号的属性值完全相同则停止继续划分
100 returnclassList[0]101 if len(dataSet[0]) == 1: #遍历完所有的特征时,仍然类标号不同的属性值,则返回出现次数最多的属性值
102 returnmajorityCnt(classList)103 bestFeat = chooseBestFeaturesplit(dataSet) #选择划分最佳的特征,返回的是特征在样本集中的列号
104 bestFeatLabel = Labels[bestFeat] #提取最佳特征的名称
105 myTree = {bestFeatLabel :{}} #创建一个字典,用于存放决策树
106 del(Labels[bestFeat]) #从特征列表中删除已经选择的最佳特征
107 featValues = [example[bestFeat] for example in dataSet ] #提取最佳特征的所有属性值
108 uniqueVals = set(featValues ) #获得最佳特征的不同的属性值
109 for value inuniqueVals :110 subLabels = Labels[:] #把去除最佳特征的特征列表赋值于subLabels
111 myTree [bestFeatLabel][value] =createTree(splitDataSet(dataSet ,bestFeat ,value ),112 subLabels ) #递归调用createTree()
113 returnmyTree114
115 #决策树的存储
116 defstoreTree(inputTree,filename):117 importpickle118 fw = open(filename,'w')119 pickle.dump(inputTree ,fw)120 fw.close()121
122 defgrabTree(filename):123 importpickle124 fr =open(filename)125 returnpickle.load(fr)126
127
128 #使用决策树的分类函数
129 defclassify(inputTree,featLabels,testVec):130 firstStr = inputTree.keys()[0] #获得距离根节点最近的最佳特征
131 secondDict = inputTree[firstStr ] #最佳特征的分支
132 featIndex = featLabels .index(firstStr) #获取最佳特征在特征列表中索引号
133 for key in secondDict .keys(): #遍历分支
134 if testVec [featIndex ] == key: #确定待查数据和最佳特征的属性值相同的分支
135 if type(secondDict [key]).__name__ == 'dict': #判断找出的分支是否是“根节点”
136 classLabel = classify(secondDict[key],featLabels ,testVec) #利用递归调用查找叶子节点
137 else:138 classLabel = secondDict [key] #找出的分支是叶子节点
139 return classLabel