#!/usr/bin/python
from math import log
from numpy import *
import operator
import copy
import pickle
def createDataSet():
    """Return the toy 'fish' dataset and its feature labels.

    Each row is [no-surfacing, flippers, class]; the last column is the
    class label. NOTE(review): this zero-argument version is shadowed by
    the file-based createDataSet(filename) defined later in this file.
    """
    samples = [
        [1, 1, 'yes'],
        [1, 1, 'yes'],
        [1, 0, 'no'],
        [0, 1, 'no'],
        [0, 1, 'no'],
    ]
    feature_names = ['no surfacing', 'flippers']
    return samples, feature_names
def calEntropy(dataSet):
    """Compute the Shannon entropy of the class labels in dataSet.

    The last element of each row is treated as the class label.
    Returns 0.0 for a perfectly pure dataset.
    """
    total = len(dataSet)
    # Tally how often each class label occurs.
    counts = {}
    for row in dataSet:
        label = row[-1]
        counts[label] = counts.get(label, 0) + 1
    # H = -sum(p * log2(p)) over all label probabilities.
    entropy = 0.0
    for cnt in counts.values():
        p = float(cnt) / float(total)
        entropy -= p * log(p) / log(2)
    return entropy
def splitDataSet(dataSet, colum, value):
    """Select the rows whose feature at index `colum` equals `value`.

    The matched feature column itself is removed from each returned row,
    so the result has one fewer column than the input.
    """
    return [row[:colum] + row[colum + 1:]
            for row in dataSet
            if row[colum] == value]
def chooseBestFeature(dataSet):
    """Return the index of the feature with the highest information gain.

    Returns -1 when no split yields a strictly positive gain over the
    base entropy of the whole dataset.
    """
    baseEntropy = calEntropy(dataSet)
    bestGain, bestFeature = 0.0, -1
    numSamples = float(len(dataSet))
    # The last column is the class label, so it is not a candidate feature.
    for idx in range(len(dataSet[0]) - 1):
        # Weighted entropy after splitting on each distinct value of
        # feature `idx`.
        weighted = 0.0
        for val in set(row[idx] for row in dataSet):
            subset = splitDataSet(dataSet, idx, val)
            weighted += (len(subset) / numSamples) * calEntropy(subset)
        gain = baseEntropy - weighted
        if gain > bestGain:
            bestGain, bestFeature = gain, idx
    return bestFeature
def majorityCnt(classList):
    """Return the most frequent class label in classList.

    Used as the leaf decision when the features are exhausted but the
    class labels are still mixed.

    Fixes three defects in the original: the loop variable was spelled
    `valve` while the body read `value` (NameError on every call), the
    final return referenced a misspelled `soredClassCount`, and
    dict.iteritems() is Python-2-only (replaced with .items()).
    """
    classCount = {}
    for value in classList:
        classCount[value] = classCount.get(value, 0) + 1
    # Sort (label, count) pairs by count, descending; take the top label.
    sortedClassCount = sorted(classCount.items(),
                              key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]
def createTree(dataSet, labels):
    """Recursively build an ID3 decision tree as nested dicts.

    The tree has the form {feature_label: {feature_value: subtree_or_leaf}}.
    NOTE: `labels` is mutated in place (the chosen feature's label is
    deleted at each level), matching the original behavior — callers
    should pass a copy if they need the list afterwards.
    """
    classList = [row[-1] for row in dataSet]
    # Base case 1: every sample has the same class — return it as a leaf.
    if classList.count(classList[0]) == len(dataSet):
        return classList[0]
    # Base case 2: only the class column remains — majority vote.
    if len(dataSet[0]) == 1:
        return majorityCnt(classList)
    bestFea = chooseBestFeature(dataSet)
    bestFeaLabel = labels[bestFea]
    myTree = {bestFeaLabel: {}}
    del labels[bestFea]
    # Recurse once per distinct value of the chosen feature; each branch
    # gets its own copy of the remaining labels.
    for value in set(row[bestFea] for row in dataSet):
        myTree[bestFeaLabel][value] = createTree(
            splitDataSet(dataSet, bestFea, value), labels[:])
    return myTree
def storeTree(inputTree, filename):
    """Serialize inputTree to `filename` with pickle.

    Fixes the original's text-mode open ('w'): pickle writes bytes and
    requires binary mode on Python 3. A context manager guarantees the
    handle is closed even if dump() raises.
    """
    with open(filename, 'wb') as fw:
        pickle.dump(inputTree, fw)
def grabTree(filename):
    """Load and return a pickled tree from `filename`.

    Fixes the original's text-mode open (pickle requires 'rb' on
    Python 3) and the leaked file handle (fr was never closed).
    NOTE(review): pickle.load executes arbitrary code on crafted input —
    only load files this program itself wrote.
    """
    with open(filename, 'rb') as fr:
        return pickle.load(fr)
def classify(inputTree, featLabels, testVec):
    """Walk the decision tree to classify a feature vector.

    inputTree  -- nested dict {feature_label: {feature_value: subtree_or_leaf}}
    featLabels -- list mapping feature labels to indices into testVec
    testVec    -- feature values for the sample to classify

    Returns the leaf class label, or None if testVec's value for the
    split feature matches no branch.

    Fixes: `inputTree.keys()[0]` fails on Python 3 (dict views are not
    indexable) — replaced with next(iter(...)); classLabel is now
    initialized so an unmatched branch returns None instead of raising
    UnboundLocalError.
    """
    firstStr = next(iter(inputTree))
    secondDict = inputTree[firstStr]
    featIndex = featLabels.index(firstStr)
    classLabel = None  # returned when no branch matches testVec's value
    for key in secondDict:
        if testVec[featIndex] == key:
            subtree = secondDict[key]
            if isinstance(subtree, dict):
                # Internal node: keep descending.
                classLabel = classify(subtree, featLabels, testVec)
            else:
                # Leaf: the stored class label.
                classLabel = subtree
    return classLabel
def createDataSet(filename):
    """Read a tab-separated lenses dataset from `filename`.

    Returns (rows, labels) where each row is a list of string fields and
    labels names the four lenses features.

    Fixes the original's leaked file handle by using a context manager.
    NOTE(review): this redefines the zero-argument createDataSet earlier
    in the file; the name is kept so the __main__ call still works.
    """
    with open(filename) as fr:
        lenses = [line.strip().split("\t") for line in fr]
    labels = ['age', 'prescript', 'astigmatic', 'tearRate']
    return lenses, labels
if __name__ == '__main__':
    # Earlier experiments with the toy dataset, kept for reference:
    # dataSet, labels = createDataSet()
    # calEntropy(dataSet)
    # fea = chooseBestFeature(dataSet)
    dataSet, labels = createDataSet("lenses.txt")
    # createTree mutates its labels argument in place, so pass a deep
    # copy and keep `labels` intact for a later classify() call.
    featLabels = copy.deepcopy(labels)
    myTree = createTree(dataSet, featLabels)
    storeTree(myTree, "storeTree.txt")
    getTree = grabTree("storeTree.txt")
    # print() with a single argument works on both Python 2 and 3;
    # the original `print getTree` statement is a SyntaxError on 3.
    print(getTree)
    # label = classify(getTree, labels, [1, 0, 0, 1])
    # print(label)
# decision_tree (ID3) with python
# (blog footer: latest recommended article published 2021-11-17 23:39:06)