#-*- coding=utf8 -*-
import numpy as np
from math import log
def createTree(dataSet,labels):
    """Recursively build an ID3 decision tree as nested dicts.

    Args:
        dataSet: list of rows (lists); every column except the last is a
            feature value and the last column is the class label.
        labels: feature names, one per feature column of dataSet. The
            caller's list is NOT modified; each level works on a copy.

    Returns:
        Either a class label (a leaf) or a dict of the form
        ``{feature_name: {feature_value: subtree, ...}}``.
    """
    classList = [row[-1] for row in dataSet]
    # Every row has the same class: pure leaf.
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    # Only the class column is left: leaf decided by majority vote.
    # majorityLabel expects the class labels, not the (unhashable) rows.
    if len(dataSet[0]) == 1:
        return majorityLabel(classList)
    bestFeat = chooseBestFeature(dataSet)
    # chooseBestFeature returns -1 when no feature has positive information
    # gain; index -1 would select the class column itself, so stop with a
    # majority-vote leaf instead of splitting on it.
    if bestFeat < 0:
        return majorityLabel(classList)
    bestfeatlabel = labels[bestFeat]
    # Build the remaining label list as a new list so the caller's list
    # (and sibling branches) are never mutated.
    subLabels = labels[:bestFeat] + labels[bestFeat + 1:]
    myTree = {bestfeatlabel: {}}
    for value in set([row[bestFeat] for row in dataSet]):
        myTree[bestfeatlabel][value] = createTree(
            splitDataSet(dataSet, bestFeat, value), subLabels)
    return myTree
def splitDataSet(dataSet,feat,value):
    """Select rows whose column *feat* equals *value*, dropping that column.

    Args:
        dataSet: list of rows (lists).
        feat: index of the column to test and remove.
        value: value that column must equal for a row to be kept.

    Returns:
        A new list of new row lists; the input rows are left untouched.
    """
    def _without_column(row):
        # Slicing copies, so extending the prefix never touches `row`.
        trimmed = row[:feat]
        trimmed.extend(row[feat + 1:])
        return trimmed

    return [_without_column(row) for row in dataSet if row[feat] == value]
def majorityLabel(classList):
    """Return the most frequent label in *classList*.

    Args:
        classList: sequence of hashable class labels.

    Returns:
        The label with the highest count. Ties are resolved by
        Counter.most_common ordering (first-seen label on Python 3.7+).

    Raises:
        IndexError: if classList is empty.
    """
    from collections import Counter
    # most_common sorts by DESCENDING count, so entry 0 is the majority
    # label (an ascending sort would pick the least common one).
    return Counter(classList).most_common(1)[0][0]
def calcEntropy(dataSet):
    """Return the base-2 Shannon entropy of the class column of *dataSet*.

    The class label is taken from the last element of each row. An empty
    dataSet yields 0.0.
    """
    labels = [row[-1] for row in dataSet]
    total = len(labels)
    result = 0.0
    for label in set(labels):
        share = labels.count(label) * 1.0 / total
        result -= share * log(share, 2)
    return result
def chooseBestFeature(dataSet):
    """Return the index of the feature column with the largest information gain.

    Every column but the last is a feature; the last is the class label.
    For each feature, the size-weighted entropy of the partitions produced
    by splitDataSet is subtracted from the entropy of the whole set.

    Returns:
        Index of the feature with the strictly largest gain above 0.0, or
        -1 if no feature has a positive gain.
    """
    total = float(len(dataSet))
    parentEntropy = calcEntropy(dataSet)
    bestGain = 0.0
    bestIndex = -1
    for col in range(len(dataSet[0]) - 1):
        weighted = 0.0
        for value in set([row[col] for row in dataSet]):
            part = splitDataSet(dataSet, col, value)
            weighted += len(part) / total * calcEntropy(part)
        gain = parentEntropy - weighted
        if gain > bestGain:
            bestGain, bestIndex = gain, col
    return bestIndex
def storeTree(inputTree, filename):
    """Serialize *inputTree* to *filename* with pickle.

    Args:
        inputTree: the tree to save (e.g. the nested dicts from createTree).
        filename: path of the file to create or overwrite.
    """
    import pickle
    # Pickle output is bytes: binary mode is required on Python 3 and avoids
    # newline translation on Windows. `with` closes the file even when
    # pickle.dump raises.
    with open(filename, 'wb') as fw:
        pickle.dump(inputTree, fw)
def grabTree(filename):
    """Load and return a pickled tree from *filename*.

    Security: pickle.load can execute arbitrary code while loading; only
    read files from a trusted source.
    """
    import pickle
    # Binary mode matches pickle's byte stream; `with` releases the handle
    # that the previous version leaked.
    with open(filename, 'rb') as fr:
        return pickle.load(fr)
# Small toy data set: two binary features and a yes/no class column.
dataSet = [[1, 1, 'yes'],
           [1, 1, 'yes'],
           [1, 0, 'no'],
           [0, 1, 'no'],
           [0, 1, 'no']]
labels = ['no surfacing', 'flippers']

if __name__ == '__main__':
    import os
    # Build and print a tree for the tab-separated lenses data set.
    # os.path.join replaces the Windows-only backslash literal, which was
    # also an invalid escape sequence.
    with open(os.path.join('Ch03', 'lenses.txt')) as fr:
        # Skip blank lines so every row has at least one column.
        lenses = [line.strip().split('\t') for line in fr if line.strip()]
    # NOTE(review): six names are listed here; only as many as there are
    # feature columns in lenses.txt are used -- confirm against the file.
    lenseLabel = ['age', 'pres', 'asti', 'tear', 'ui', 'iu']
    lensetree = createTree(lenses, lenseLabel)
    print(lensetree)
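# 决策树实现 (Decision tree implementation)
# 最新推荐文章于 2024-03-29 13:08:11 发布 (web-page residue: "latest recommended article published 2024-03-29 13:08:11")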