# -*- coding: UTF-8 -*-
# Decision-tree (ID3) building blocks: Shannon entropy, dataset splitting,
# and information-gain feature selection, plus a small demo script.
# NOTE(review): removed unused `from numpy import *` and matplotlib imports;
# nothing in this file uses them.
from math import log


def calcShannonEnt(dataSet):
    """Return the Shannon entropy of dataSet's class labels.

    dataSet: list of records; the last element of each record is the class
    label. Entropy = -sum(p * log2(p)) over the label distribution.
    """
    numEntries = len(dataSet)          # total number of records
    labelCounts = {}                   # label -> occurrence count
    for featVec in dataSet:
        currentLabel = featVec[-1]     # class label is the last column
        labelCounts[currentLabel] = labelCounts.get(currentLabel, 0) + 1
    shannonEnt = 0.0
    for key in labelCounts:
        prob = float(labelCounts[key]) / numEntries
        # BUGFIX: original called math.log(prob, 2), but only
        # `from math import log` was imported, so `math` was unresolved.
        shannonEnt -= prob * log(prob, 2)
    return shannonEnt


def createDataSet():
    """Return the toy "is it a fish?" dataset and its feature labels."""
    dataSet = [[1, 1, 'yes'],
               [1, 1, 'yes'],
               [1, 0, 'no'],
               [0, 1, 'no'],
               [0, 1, 'no']]
    labels = ['no surfacing', 'flippers']
    return dataSet, labels


def filematrix(filename):
    """Parse a space-separated data file into (data, labels).

    Each line: float feature columns followed by one class-label column.
    """
    data = []
    labels = []
    # BUGFIX: use a context manager so the file handle is always closed
    # (the original opened the file and never closed it).
    with open(filename) as fr:
        for line in fr:
            tokens = line.strip().split(' ')
            data.append([float(tk) for tk in tokens[:-1]])
            labels.append(tokens[-1])
    return data, labels


def splitDataSet(dataSet, axis, value):
    """Return the records whose feature `axis` equals `value`,
    with that feature column removed from each returned record.

    dataSet: dataset to split
    axis:    index of the feature to split on
    value:   feature value selecting which records to keep
    """
    retDataSet = []
    for featVec in dataSet:
        if featVec[axis] == value:
            reducedFeatVec = featVec[:axis]
            reducedFeatVec.extend(featVec[axis + 1:])
            retDataSet.append(reducedFeatVec)
    return retDataSet


def chooseBestFeatureToSplit(dataSet):
    """Return the index of the feature with the largest information gain
    (ID3 criterion), or -1 if no split improves on the base entropy."""
    numFeatures = len(dataSet[0]) - 1        # last column is the label
    baseEntropy = calcShannonEnt(dataSet)    # entropy before any split
    bestInfoGain = 0.0
    bestFeature = -1
    for i in range(numFeatures):
        uniqueVals = set(example[i] for example in dataSet)  # distinct values
        newEntropy = 0.0
        for value in uniqueVals:
            subDataSet = splitDataSet(dataSet, i, value)
            prob = len(subDataSet) / float(len(dataSet))
            newEntropy += prob * calcShannonEnt(subDataSet)  # weighted entropy
        infoGain = baseEntropy - newEntropy
        if infoGain > bestInfoGain:
            bestInfoGain = infoGain
            bestFeature = i
    return bestFeature


# --- demo script ---
# BUGFIX: the original read "highandweight.txt" unconditionally at import
# time, crashing when the file is absent; it is optional demo data, so a
# missing file is now tolerated. BUGFIX: Python 2 `print x` statements
# converted to Python 3 print() calls.
try:
    data, labels = filematrix("highandweight.txt")
except OSError:
    data, labels = [], []

myDat, labels = createDataSet()
print(calcShannonEnt(myDat))
print(splitDataSet(myDat, 0, 1))
print(chooseBestFeatureToSplit(myDat))
# Majority vote over class labels.
def majorityCnt(classList):
    """Return the class label that occurs most often in classList."""
    classCount = {}
    for vote in classList:
        classCount[vote] = classCount.get(vote, 0) + 1
    # BUGFIX: original used dict.iteritems() (Python 2 only) and referenced
    # `operator.itemgetter` without ever importing `operator`.
    return max(classCount, key=classCount.get)


def createTree(dataSet, labels):
    """Recursively build an ID3 decision tree as nested dicts.

    dataSet: records whose last column is the class label
    labels:  human-readable feature names, parallel to the feature columns
    Returns either a leaf (a class label) or {featureLabel: {value: subtree}}.

    BUGFIX: works on a copy of `labels`; the original del'ed entries from the
    caller's list, which forced classify() to rebuild the label list itself.
    """
    classList = [example[-1] for example in dataSet]
    if classList.count(classList[0]) == len(classList):
        return classList[0]          # all records share one class -> leaf
    if len(dataSet[0]) == 1:
        return majorityCnt(classList)  # features exhausted -> majority leaf
    labels = labels[:]               # copy so the caller's list is untouched
    bestFeat = chooseBestFeatureToSplit(dataSet)
    bestFeatLabel = labels[bestFeat]
    myTree = {bestFeatLabel: {}}
    del labels[bestFeat]
    featValues = [example[bestFeat] for example in dataSet]
    for value in set(featValues):
        subLabels = labels[:]        # each branch gets its own label list
        myTree[bestFeatLabel][value] = createTree(
            splitDataSet(dataSet, bestFeat, value), subLabels)
    return myTree


def classify(inputTree, featLables, testVec):
    """Classify testVec by walking the tree built by createTree().

    Returns the predicted class label, or None if testVec's value at the
    current feature matches no branch.
    """
    # BUGFIX: dict.keys() is not indexable in Python 3; take the first
    # (only) key via next(iter(...)).
    firstStr = next(iter(inputTree))
    secondDict = inputTree[firstStr]
    # BUGFIX: use the featLables the caller passed in instead of rebuilding
    # the dataset here — that hack only existed because the original
    # createTree() mutated the caller's labels list.
    featIndex = featLables.index(firstStr)
    classLabel = None  # BUGFIX: original returned an unbound local on no match
    for key in secondDict:
        if testVec[featIndex] == key:
            if isinstance(secondDict[key], dict):
                classLabel = classify(secondDict[key], featLables, testVec)
            else:
                classLabel = secondDict[key]
    return classLabel


# --- demo script (Python 3 print calls) ---
myTree = createTree(myDat, labels)
print(myTree)
print(classify(myTree, labels, [1, 1]))
决策树算法学习笔记(二)
最新推荐文章于 2022-10-27 13:39:45 发布