# DecisionTree.py
from math import log
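# A decision tree uses Shannon entropy from information theory to measure how
# mixed the class labels of a data set are: the higher the entropy of a set,
# the more disordered it is.
# To choose the splitting attribute, first compute the entropy of the current
# set, then the weighted entropy of the subsets produced by splitting on each
# attribute; the attribute whose split lowers the entropy the most (i.e. has
# the largest information gain) becomes the current splitting attribute.
# Keep splitting on the best attribute until every node is pure or every
# attribute has been used; at that point the tree is complete.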
def calcShannonEnt(dataSet):
    """Return the Shannon entropy (base 2) of the class labels in dataSet.

    dataSet is a sequence of rows; the last element of each row is taken as
    that row's class label. An empty dataSet yields 0.0.
    """
    numEntries = len(dataSet)
    # Count how many rows carry each class label.
    labelCounts = {}
    for featVec in dataSet:
        currentLabel = featVec[-1]
        labelCounts[currentLabel] = labelCounts.get(currentLabel, 0) + 1
    # Entropy = -sum(p * log2(p)) over the observed label frequencies.
    shannonEnt = 0.0
    for count in labelCounts.values():
        prob = float(count) / numEntries
        shannonEnt -= prob * log(prob, 2)
    return shannonEnt
def splitDataSet(dataSet, axis, value):
    """Return the rows of dataSet whose column `axis` equals `value`.

    The matching column is removed from every returned row. The rows of
    dataSet itself are not modified; new row objects are built by slicing.
    """
    return [
        featVec[:axis] + featVec[axis + 1:]
        for featVec in dataSet
        if featVec[axis] == value
    ]
def chooseBestFeatureToSplit(dataSet):
    """Return the index of the feature whose split gives the largest information gain.

    Every column of a row except the last is treated as a feature; the last
    column is the class label. Returns -1 when dataSet is empty or when no
    feature yields a strictly positive information gain.
    """
    if not dataSet:
        return -1
    numFeatures = len(dataSet[0]) - 1
    baseEntropy = calcShannonEnt(dataSet)
    totalRows = float(len(dataSet))
    bestInfoGain = 0.0
    bestFeature = -1
    for i in range(numFeatures):
        uniqueVals = {example[i] for example in dataSet}
        # Weighted entropy of the subsets produced by splitting on feature i.
        newEntropy = 0.0
        for value in uniqueVals:
            subDataSet = splitDataSet(dataSet, i, value)
            prob = len(subDataSet) / totalRows
            newEntropy += prob * calcShannonEnt(subDataSet)
        infoGain = baseEntropy - newEntropy
        # Strict comparison: on ties the earliest feature is kept.
        if infoGain > bestInfoGain:
            bestInfoGain = infoGain
            bestFeature = i
    return bestFeature
def majorityCnt(classList):
    """Return the class label that occurs most often in classList.

    On a tie, the label that appears first in classList wins.
    Raises ValueError if classList is empty.
    """
    classCount = {}
    # Tally the votes from classList (the original looped over the empty
    # tally dict itself, so nothing was ever counted).
    for vote in classList:
        classCount[vote] = classCount.get(vote, 0) + 1
    if not classCount:
        raise ValueError("majorityCnt() requires a non-empty classList")
    # max() returns the first maximal entry in insertion order.
    return max(classCount.items(), key=lambda d: d[1])[0]
def createTree(dataSet, labels):
    """Recursively build a decision tree represented as nested dicts.

    dataSet: list of rows; the last element of each row is the class label.
    labels: names of the feature columns, in column order. The caller's list
        is not modified.

    Returns a class label for a leaf, otherwise a dict of the form
    {featureLabel: {featureValue: subtree, ...}}.
    """
    classList = [example[-1] for example in dataSet]
    # Stop when every row belongs to the same class.
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    # Stop when only the class column is left: no feature remains to split on.
    if len(dataSet[0]) == 1:
        return majorityCnt(classList)
    bestFeat = chooseBestFeatureToSplit(dataSet)
    # -1 means no feature gives a positive information gain (e.g. rows with
    # identical feature values but different classes). Using -1 as an index
    # would split on the class column itself, so emit a majority-vote leaf.
    if bestFeat < 0:
        return majorityCnt(classList)
    bestFeatLabel = labels[bestFeat]
    print(bestFeat)
    print(bestFeatLabel)
    # A dict keyed by the chosen feature's label is the new internal node.
    myTree = {bestFeatLabel: {}}
    # Labels for the children: the same list without the consumed feature.
    subLabels = labels[:bestFeat] + labels[bestFeat + 1:]
    uniqueVals = set(example[bestFeat] for example in dataSet)
    print(uniqueVals)
    for value in uniqueVals:
        # Attach one subtree per observed value of the chosen feature.
        myTree[bestFeatLabel][value] = createTree(
            splitDataSet(dataSet, bestFeat, value), subLabels
        )
    return myTree
def classify(inputTree, featLabels, testVec):
    """Classify testVec by walking the nested-dict decision tree inputTree.

    inputTree: dict of the form {featureLabel: {featureValue: subtree_or_label}}.
    featLabels: feature names in column order; used to locate the tested
        feature's position in testVec.
    testVec: feature values of the sample to classify.

    Returns the class label stored at the leaf that is reached.
    Raises ValueError if the tree has no branch for the sample's feature value
    (the original code failed with UnboundLocalError in that case).
    """
    firstStr = list(inputTree.keys())[0]
    secondDict = inputTree[firstStr]
    featIndex = featLabels.index(firstStr)
    featValue = testVec[featIndex]
    for key, subtree in secondDict.items():
        if featValue == key:
            # A dict is an internal node: keep descending. Anything else is a leaf.
            if isinstance(subtree, dict):
                return classify(subtree, featLabels, testVec)
            return subtree
    raise ValueError(
        "no branch for value %r of feature %r" % (featValue, firstStr)
    )
def creatDataSet():
    """Return a small sample data set and its feature labels.

    Each row is [no surfacing, flippers, class]; the class is 'yes' or 'no'.
    """
    dataSet = [
        [1, 1, 'yes'],
        [1, 1, 'yes'],
        [1, 0, 'no'],
        [0, 1, 'no'],
        [0, 1, 'no'],
    ]
    labels = ['no surfacing', 'flippers']
    return dataSet, labels
# Demo: build a tree from the sample data set and classify two samples.
myData, labels = creatDataSet()
# The labels are printed before and after createTree() so the two outputs can
# be compared to check that the caller's list was left unchanged.
print("1",labels)
myTree = createTree(myData, labels)
print("2",labels)
print(myTree) # expected: {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
# Classify two feature vectors given as [no surfacing, flippers].
test1 = classify(myTree,labels,[1,0])
test2 = classify(myTree,labels,[1,1])
print("test1: ",test1) # expected class: no
print("test2: ",test2) # expected class: yes
# Simple application: build a decision tree from a data file to obtain classifications.
import ch2.DecisionTree as dTree
# Read the tab-separated lenses data set. The context manager closes the file
# even if parsing fails; blank lines are skipped so that a trailing newline
# does not produce a bogus [''] row.
with open("lenses.txt") as fr:
    lenses = [inst.strip().split('\t') for inst in fr if inst.strip()]
lensesLabels = ['age', 'prescript', 'astigmatic', 'tearRate']
lensesTree = dTree.createTree(lenses, lensesLabels)
print(lensesTree)
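# Once the decision tree has been built it can be serialized, so that it can be
# read back into memory when needed instead of being trained again, which saves time.
# Most Python objects can be serialized with pickle.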
def storeTree(inputTree, filename):
    """Serialize inputTree to the file `filename` using pickle.

    An existing file at that path is overwritten.
    """
    import pickle
    # pickle writes bytes, so the file must be opened in binary mode; text
    # mode ('w') makes pickle.dump raise TypeError on Python 3.
    with open(filename, 'wb') as fw:
        pickle.dump(inputTree, fw)
def grabTree(filename):
    """Load and return the object pickled in the file `filename`.

    WARNING: unpickling can execute arbitrary code. Only call this on files
    that come from a trusted source (e.g. ones written by storeTree).
    """
    import pickle
    # pickle reads bytes, so open in binary mode; the context manager closes
    # the handle that the original left open.
    with open(filename, 'rb') as fr:
        return pickle.load(fr)
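# test.py: Python exercises and notes made while studying.
# Python passes a reference to a list into a function, so in-place modifications
# made inside the function stay visible to the caller for the rest of the list's lifetime.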
import matplotlib.pyplot as plt
# Disabled scratch code kept as a bare string literal, so it is never executed:
# a matplotlib demo that draws one annotated "decision node" box and one
# annotated "leaf node" box.
'''
#决策树注解绘图
decisionNode = dict(boxstyle="sawtooth", fc="0.8")
leafNode = dict(boxstyle="round4", fc="0.8")
arrow_args = dict(arrowstyle="<-")
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
createPlot.ax1.annotate(nodeTxt, xy=parentPt,
xycoords='axes fraction',
xytext=centerPt, textcoords='axes fraction',
va="center", ha="center", bbox=nodeType, arrowprops=arrow_args)
def createPlot():
fig = plt.figure(1, facecolor='white')
fig.clf()
createPlot.ax1 = plt.subplot(111, frameon=False)
plotNode('a decision node', (0.5, 0.1), (0.1, 0.5), decisionNode)
plotNode('a leaf node', (0.8, 0.1), (0.3, 0.8), leafNode)
plt.show()
createPlot()
'''
# Disabled scratch code kept as a bare string literal, so it is never executed:
# it contrasts list.append (adds the other list as one nested element, giving
# 4 elements) with list.extend (adds each element, giving 6 elements).
'''
#测试extend和append的不同
a = [1,2,3]
b = [3,4,5]
a.append(b)
print(a)
#使用append结果为[1, 2, 3, [3, 4, 5]],即将b整个列表当成一个元素添加,添加后有4个元素
c = [1,2,3]
d = [3,4,5]
c.extend(d)
print(c)
#使用extend结果为[1, 2, 3, 3, 4, 5],添加后有6个 元素
'''
# Disabled scratch code kept as a bare string literal, so it is never executed:
# it wraps dict.keys() in list() so that the keys can be accessed by index.
'''
#测试dict_keys
#python3中dict的keys(), values(), items()返回的都是迭代器,用list转化为列表可用索引调用得到每个key
d= {'a':{'d':2},'b':1,'c':{}}
print(list(d.keys()))
print(list(d.keys())[0])
'''