# 使用ID3算法构建一个决策树 (Build a decision tree with the ID3 algorithm.)
from math import log
import operator
def calShang(dataSet):
    """Return the Shannon entropy, in bits, of the class labels in *dataSet*.

    The class label of a row is its last element. An empty dataset has
    entropy 0.0.
    """
    total = len(dataSet)
    tally = {}
    for row in dataSet:
        label = row[-1]
        tally[label] = tally.get(label, 0) + 1

    entropy = 0.0
    for count in tally.values():
        p = count / float(total)
        entropy -= p * log(p, 2)
    return entropy
def splitDataSet(dataSet, axis, value):
    """Return the rows of *dataSet* whose column *axis* equals *value*.

    Column *axis* is dropped from each returned row. The returned rows are
    newly built lists, so *dataSet* itself is left unchanged.
    """
    return [
        row[:axis] + row[axis + 1:]
        for row in dataSet
        if row[axis] == value
    ]
def chooseBestFeatureTosplit(dataSet):
    """Return the index of the feature column with the highest information gain.

    The last column of each row is the class label and is not a candidate.
    Information gain is the entropy of the whole dataset minus the weighted
    entropy of the subsets produced by splitting on a column. When gains
    tie, the lowest index wins. If no column gives a strictly positive gain,
    0 is returned.
    """
    rowCount = float(len(dataSet))
    baseEntropy = calShang(dataSet)
    bestGain, bestIndex = 0.0, 0
    for col in range(len(dataSet[0]) - 1):
        column = [row[col] for row in dataSet]
        conditional = 0.0
        for value in set(column):
            subset = splitDataSet(dataSet, col, value)
            conditional += len(subset) / rowCount * calShang(subset)
        gain = baseEntropy - conditional
        if gain > bestGain:
            bestGain, bestIndex = gain, col
    return bestIndex
def majorityCnt(classList):
    """Return the most frequent value in *classList*.

    When counts tie, the value that appears first in *classList* wins,
    because the sort is stable. An empty list raises IndexError.
    """
    votes = {}
    for label in classList:
        votes[label] = votes.get(label, 0) + 1
    ranked = sorted(votes, key=votes.get, reverse=True)
    return ranked[0]
def createTree(dataSet, labels):
    """Recursively build an ID3 decision tree.

    Args:
        dataSet: non-empty list of rows (lists). Every element except the
            last is a feature value; the last element is the class label.
        labels: feature names, where labels[i] names column i of dataSet.
            This list is not modified.

    Returns:
        Either a class label (a leaf), or a nested dict of the form
        {feature_name: {feature_value: subtree, ...}}.
    """
    classList = [row[-1] for row in dataSet]
    # Every row has the same class: this node is a leaf.
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    # Only the class column is left, so fall back to a majority vote.
    if len(dataSet[0]) == 1:
        return majorityCnt(classList)

    bestFeat = chooseBestFeatureTosplit(dataSet)
    bestFeatLabel = labels[bestFeat]
    # Build a new label sequence for the remaining columns rather than
    # deleting from `labels`, so the caller's list is left intact.
    subLabels = labels[:bestFeat] + labels[bestFeat + 1:]

    mytree = {bestFeatLabel: {}}
    # dict.fromkeys de-duplicates in first-seen order, which keeps the
    # branch order stable across runs. Iterating a set of strings would
    # vary with hash randomization.
    for value in dict.fromkeys(row[bestFeat] for row in dataSet):
        mytree[bestFeatLabel][value] = createTree(
            splitDataSet(dataSet, bestFeat, value), subLabels
        )
    return mytree
def createDataSet():
    """Return the 14-row weather "go out to play" dataset and its column names.

    Each row is a list of five strings. The columns are weather,
    temperature, humidity, wind, and whether to go out to play; the last
    column is the class label ("yes"/"no"). A fresh list of fresh row lists
    and a fresh label list are built on every call.
    """
    rows = (
        ("sunny", "hot", "high", "false", "no"),
        ("sunny", "hot", "high", "true", "no"),
        ("overcast", "hot", "high", "false", "yes"),
        ("rainy", "mild", "high", "false", "yes"),
        ("rainy", "cool", "normal", "false", "yes"),
        ("rainy", "cool", "normal", "true", "no"),
        ("overcast", "cool", "normal", "true", "yes"),
        ("sunny", "mild", "high", "false", "no"),
        ("sunny", "cool", "normal", "false", "yes"),
        ("rainy", "mild", "normal", "false", "yes"),
        ("sunny", "mild", "normal", "true", "yes"),
        ("overcast", "mild", "high", "true", "yes"),
        ("overcast", "hot", "normal", "false", "yes"),
        ("rainy", "mild", "high", "true", "no"),
    )
    featureNames = ['天气', '温度', '湿度', '风', '是否出去玩']
    # Rows must be lists: splitDataSet slices them and relies on list behaviour.
    return [list(r) for r in rows], featureNames
# Build and print the tree only when this file is run as a script, not on import.
if __name__ == "__main__":
    myData, labels = createDataSet()
    print(createTree(myData, labels))