from math import log
def createDataSet():
    """Build the small sample data set used by the decision-tree examples.

    Returns:
        A ``(dataSet, labels)`` tuple. Each row of ``dataSet`` holds two
        feature values followed by a class label in the last column;
        ``labels`` holds the names of the two feature columns.
    """
    dataSet = [
        [1, 1, 'yes'],
        [1, 1, 'yes'],
        [1, 0, 'no'],
        [0, 1, 'no'],
        [0, 1, 'no'],
    ]
    labels = ['no surfacing', 'flippers']
    return dataSet, labels
def calcShannonEnt(dataSet):
    """Compute the base-2 Shannon entropy of the class labels in *dataSet*.

    Args:
        dataSet: Sequence of rows; the last element of each row is taken
            as that row's class label.

    Returns:
        The entropy as a float. An empty data set yields ``0.0`` because
        there are no labels to sum over.
    """
    numEntries = len(dataSet)

    # Tally how many rows carry each class label.
    labelCounts = {}
    for featVec in dataSet:
        currentLabel = featVec[-1]
        labelCounts[currentLabel] = labelCounts.get(currentLabel, 0) + 1

    # H = -sum(p * log2(p)) over the observed labels.
    shannonEnt = 0.0
    for count in labelCounts.values():
        prob = float(count) / numEntries
        shannonEnt -= prob * log(prob, 2)
    return shannonEnt
在python命令提示符下输入以下内容:
reload(trees)
myDat,labels = trees.createDataSet()
在 Python 3 中 reload 不再是内置函数(需要使用 importlib.reload),因此上面两行改为了:
import trees

myDat, labels = trees.createDataSet()
print(myDat)
# Overwrite the last element (the class label) of row 0 with 'maybe'.
myDat[0][-1] = 'maybe'
def splitDataSet(dataSet, axis, value):
    """Select the rows whose feature at *axis* equals *value*.

    The matching column is removed from every returned row, and the input
    rows are left untouched (new rows are built by slicing).

    Args:
        dataSet: Sequence of rows (lists) to filter.
        axis: Index of the feature column to test and then drop.
        value: Feature value a row must have at ``axis`` to be kept.

    Returns:
        A new list of rows, each one the original row without the
        ``axis`` column. For example, with ``axis == 0`` and
        ``value == 1`` the row ``[1, 1, 'yes']`` becomes ``[1, 'yes']``.
    """
    retDataSet = []
    for featVec in dataSet:
        if featVec[axis] == value:
            # Everything before the split column, then everything after it.
            reducedFeatVec = featVec[:axis]
            reducedFeatVec.extend(featVec[axis + 1:])
            retDataSet.append(reducedFeatVec)
    return retDataSet
def chooseBestFeatureToSplit(dataSet):
    """Pick the feature column whose split gives the largest information gain.

    Args:
        dataSet: Non-empty sequence of rows; the last element of each row
            is the class label, every other element is a feature value.

    Returns:
        The index of the best feature, or ``-1`` when no feature yields a
        strictly positive information gain.
    """
    numFeatures = len(dataSet[0]) - 1  # the last column is the class label
    baseEntropy = calcShannonEnt(dataSet)  # entropy before any split
    bestInfoGain = 0.0
    bestFeature = -1
    for i in range(numFeatures):
        # Distinct values taken by feature i across the data set.
        uniqueVals = {example[i] for example in dataSet}
        newEntropy = 0.0
        for value in uniqueVals:
            subDataSet = splitDataSet(dataSet, i, value)
            # Weight each subset's entropy by the fraction of rows it holds.
            prob = len(subDataSet) / float(len(dataSet))
            newEntropy += prob * calcShannonEnt(subDataSet)
        # The larger the entropy reduction, the larger the information gain.
        infoGain = baseEntropy - newEntropy
        if infoGain > bestInfoGain:
            bestInfoGain = infoGain
            bestFeature = i
    return bestFeature
def majorityCnt(classList):
    """Return the class label that occurs most often in *classList*.

    When several labels share the highest count, the one that appears
    first in ``classList`` is returned.

    Args:
        classList: Non-empty iterable of hashable class labels.

    Returns:
        The most frequent label.

    Raises:
        IndexError: If ``classList`` is empty.
    """
    # Imported locally: the original relied on `operator` and on the
    # Python 2-only dict.iteritems(); Counter needs neither.
    from collections import Counter

    return Counter(classList).most_common(1)[0][0]