from numpy import *
from math import log
import operator
def calcShannonEnt(dataSet):
    """Return the Shannon entropy (base 2) of the class labels in dataSet.

    The class label of each record is read from its last element.
    An empty dataSet yields 0.0 because there are no labels to sum over.
    """
    total = len(dataSet)
    # Tally how many records carry each class label.
    tally = {}
    for record in dataSet:
        label = record[-1]
        if label in tally:
            tally[label] += 1
        else:
            tally[label] = 1
    # Accumulate -p * log2(p) over the observed labels.
    entropy = 0.0
    for count in tally.values():
        prob = float(count) / total
        entropy -= prob * log(prob, 2)
    return entropy
def createDataSet():
    """Return the built-in sample data set together with its feature names.

    Each record is [feature 0, feature 1, class label]; the second
    returned list names the two feature columns in order.
    """
    featureNames = ['no surfacing', 'flippers']
    samples = [
        [1, 1, 'yes'],
        [1, 1, 'yes'],
        [1, 0, 'no'],
        [0, 1, 'no'],
        [0, 1, 'no'],
    ]
    return samples, featureNames
def splitDataSet(dataSet, axis, value):
    """Return the records whose column `axis` equals `value`, minus that column.

    The input records are left untouched; each returned record is a new
    list made of the elements before and after position `axis`.
    """
    matches = []
    for row in dataSet:
        if row[axis] != value:
            continue
        # Copy everything left of the split column, then append the rest.
        reduced = row[:axis]
        reduced.extend(row[axis + 1:])
        matches.append(reduced)
    return matches
def chooseBestFeatureToSplit(dataSet):
    """Return the index of the feature with the largest information gain.

    The gain of a feature is the entropy of dataSet minus the weighted
    entropy of the subsets produced by splitting on that feature.
    Returns -1 when the records carry no feature columns.
    """
    baseEntropy = calcShannonEnt(dataSet)
    # The last column is the class label, so it is not a candidate.
    featureCount = len(dataSet[0]) - 1
    winner, winnerGain = -1, -1
    for index in range(featureCount):
        distinct = set([example[index] for example in dataSet])
        conditional = 0.0
        for value in distinct:
            subset = splitDataSet(dataSet, index, value)
            weight = len(subset) / float(len(dataSet))
            conditional += weight * calcShannonEnt(subset)
        gain = baseEntropy - conditional
        # Strict comparison: on equal gain the earlier feature is kept.
        if gain > winnerGain:
            winner, winnerGain = index, gain
    return winner
def majorityCnt(classList):
    """Return the class label that occurs most often in classList.

    Ties are resolved by the iteration order of the internal count dict
    combined with the stable sort (on Python 3.7+ that means the label
    seen first wins).

    Raises:
        IndexError: if classList is empty.
    """
    classCount = {}
    for item in classList:
        # Default to 0 so the first occurrence of a label can be counted.
        classCount[item] = classCount.get(item, 0) + 1
    # items() must be called; it works on both Python 2 and Python 3.
    sortedClass = sorted(classCount.items(),
                         key=operator.itemgetter(1),
                         reverse=True)
    return sortedClass[0][0]
def createTree(dataSet, labels):
    """Recursively build a decision tree from dataSet.

    Args:
        dataSet: list of records; the last element of each record is its
            class label, the preceding elements are feature values.
        labels: names of the feature columns, in column order. This
            sequence is not modified.

    Returns:
        A class label when every record shares one class or no feature
        columns remain (majority vote); otherwise a nested dict of the
        form {feature name: {feature value: subtree or class label}}.
    """
    classList = [item[-1] for item in dataSet]
    # All records agree on the class: this branch is a leaf.
    if len(set(classList)) == 1:
        return classList[0]
    # Only the class column is left: fall back to the majority class.
    if len(dataSet[0]) == 1:
        return majorityCnt(classList)
    bestFeature = chooseBestFeatureToSplit(dataSet)
    bestLabel = labels[bestFeature]
    # Build the reduced label list as a new object instead of deleting
    # from `labels`, so the caller's list survives the call intact.
    subLabels = labels[:bestFeature] + labels[bestFeature + 1:]
    uniqueVals = set([example[bestFeature] for example in dataSet])
    trees = {bestLabel: {}}
    for value in uniqueVals:
        subDataSet = splitDataSet(dataSet, bestFeature, value)
        trees[bestLabel][value] = createTree(subDataSet, subLabels)
    return trees
def classify(inputTree, featLabels, testVec):
    """Classify testVec by walking the decision tree inputTree.

    Args:
        inputTree: nested dict of the form
            {feature name: {feature value: subtree or class label}}.
        featLabels: feature names; the position of a name gives the index
            of that feature inside testVec.
        testVec: feature values of the sample to classify.

    Returns:
        The class label stored at the leaf that testVec reaches.

    Raises:
        ValueError: if the feature tested at a node is not listed in
            featLabels, or if the node has no branch for the sample's
            value of that feature.
    """
    # The node dict holds a single key: the name of the tested feature.
    firstStr = list(inputTree)[0]
    secTree = inputTree[firstStr]
    try:
        featIndex = featLabels.index(firstStr)
    except ValueError:
        # Fail here with a clear message rather than continuing with an
        # unbound feature index.
        raise ValueError("List does not contain value: %r" % (firstStr,))
    featValue = testVec[featIndex]
    for key, branch in secTree.items():
        if featValue == key:
            if isinstance(branch, dict):
                return classify(branch, featLabels, testVec)
            return branch
    raise ValueError("no branch for %r = %r in the tree"
                     % (firstStr, featValue))
def storeTree(inputTree, filename):
    """Serialize inputTree to the file at filename using pickle.

    An existing file at that path is overwritten.
    """
    import pickle
    # pickle emits bytes, so the file is opened in binary mode; the with
    # block closes it even when dump() raises.
    with open(filename, 'wb') as fw:
        pickle.dump(inputTree, fw)
def grabTree(filename):
    """Load and return the object pickled in the file at filename.

    NOTE: pickle.load can execute arbitrary code while unpickling, so
    this must only be used on files from a trusted source.
    """
    import pickle
    # Binary mode matches the pickle byte format; the with block makes
    # sure the handle is closed once the object has been read.
    with open(filename, 'rb') as fr:
        return pickle.load(fr)
# Demo run: build a decision tree from the sample data and display it.
dataSet, label = createDataSet()
trees = createTree(dataSet, label)
print(trees)
# Fetch the sample data again so classify receives a freshly built label list.
dataSet, label = createDataSet()
r = classify(trees, label, [0, 1])
print(r)
'''
fr=open('lenses.txt')
lenses=[inst.strip().split('\t') for inst in fr.readlines()]
lenseLabel=['age', 'prescript', 'astigmatic', 'tearRate']
trees=createTree(lenses,lenseLabel)
lenseLabel=['age', 'prescript', 'astigmatic', 'tearRate']
result=classify(trees,lenseLabel,['pre','myope','no','normal'])
print result
'''