这是看机器学习实战和周志华的西瓜书的决策树代码,自己敲下来的代码,加上自己的注释理解。相关数据应该到处都能下载到。可以直接运行代码。
#coding=utf-8
from math import log
import operator
'''信息熵计算函数'''
def calcShannonEnt(dataset):
numSamples=len(dataset)
labelCounts={}
for featVec in dataset:
currentLabel=featVec[-1]
'''if currentLabel not in labelCounts.keys():
labelCounts[currentLabel]=0
labelCounts[currentLabel]+=1'''
#两种写法都可以
labelCounts.setdefault(currentLabel,0)
labelCounts[currentLabel]+=1
Ent=0
for label in labelCounts:
prob=float(labelCounts[label])/numSamples
Ent-=prob*log(prob,2)
return Ent
def createDataSet():
dataSet=[[1,1,"yes"],[1,1,"yes"],[1,0,"no"],[0,1,"no"],[0,1,"no"]]
labels=["no surfacing","flippers"] #属性名
return dataSet,labels
myData,labels=createDataSet()
#myData[0][2]="maybe"
print calcShannonEnt(myData)
'''按照给定特征划分数据集'''
def splitDataSet(dataSet,axis,value):#按照第axis列的属性的属性值value来划分数据集
retDataSet=[]
for featVec in dataSet:
if featVec[axis]==value:
reducedFeatVec=featVec[:axis]
reducedFeatVec.extend(featVec[axis+1:]) #按照属性即某个属性值划分,得到该属性值的所有样本
retDataSet.append(reducedFeatVec) #这样有助于计算每个属性的属性值划分的数据集的信息熵
return retDataSet
print splitDataSet(myData,0,0)
'''利用信息增益选择最优的划分属性'''
def chooseBestFeatureToSplit(dataSet):
numFeatures = len(dataSet[0]) - 1 # 属性特征数
baseEnt = calcShannonEnt(dataSet) # 数据集的信息熵
bestFeature = -1 # 最优属性
bestInfoGain = 0.0 # 要求最大信息增益,先初始化为0.0
for i in range(numFeatures):
featList = [example[i] for example in dataSet] # 第i个特征属性的取值,可能有重复的值,下面通过set集合来唯一化
uniqueVals = set(featList)
newEntropy = 0.0
for value in uniqueVals:
subDataSet = splitDataSet(dataSet, i, value) # y用第i个属性的value值进行划分属性
prob = float(len(subDataSet)) / len(dataSet)
newEntropy += prob * calcShannonEnt(subDataSet)
infoGain = baseEnt - newEntropy
if (infoGain > bestInfoGain):
bestInfoGain = infoGain
bestFeature = i
return bestFeature
print chooseBestFeatureToSplit(myData)
'''利用增益率选择最优的划分属性'''
def Gain_ratioToChoose(dataset):
#这里可以通过调用信息增益函数来求得相应的增益率
#为了区分开来,另写一个增益率函数
numFeatures=len(dataset[0])-1
beatFeature=-1
baseEnt=calcShannonEnt(dataset)
bestGain_ratio=0.0
for i in range(numFeatures):
featList=[example[i] for example in dataset]
uniqueVals=set(featList)
newEnt=0.0
IV=0.0
for value in uniqueVals:
subDataset=splitDataSet(dataset,i,value)
prob=float(len(subDataset))/len(dataset)
newEnt+=prob*calcShannonEnt(subDataset)
IV-=prob*log(prob,2)
infoGain=baseEnt-newEnt
gain_ratio=infoGain/IV
if gain_ratio>bestGain_ratio:
bestGain_ratio=gain_ratio
bestFeature=i
return bestFeature
print "增益率:",Gain_ratioToChoose(myData)
'''下面采用基尼指数求最优划分属性'''
#先定义基尼值的函数
def GiniValue(dataset): #根据周志华西瓜书的基尼值的求解公式而来
classIndex=len(dataset[0])-1
classList=[example[classIndex] for example in dataset]
uniqueClasses=set(classList)
Gini=0.0
probSqureSum=0.0
for c in uniqueClasses:
prob=float(classList.count(c))/len(classList)
probSqureSum+=prob*prob
Gini=1-probSqureSum
return Gini
'''得到基尼值的计算方法,可直接通过它来计算基尼指数,然后来选择最优划分属性'''
'''选择基尼指数最小的属性,基尼指数越小,基尼值越大,则相同类别的比重多,则数据集纯度越高'''
def Gini_indexToChoose(dataset):
numFeatures=len(dataset[0])-1
minGini_index=10000.0
bestFeature=-1
for i in range(numFeatures):
featList=[example[i] for example in dataset]
uniqueVals=set(featList)
newGini_index=0.0
for value in uniqueVals:
subDataSet=splitDataSet(dataset,i,value)
prob=float(len(subDataSet))/len(dataset)
newGini_index+=prob*GiniValue(subDataSet)
if newGini_index<minGini_index:
minGini_index=newGini_index
bestFeature=i
return bestFeature
print "Gini_index:",Gini_indexToChoose(myData)
'''当所有属性划分完时,但是类标签依然不唯一,则用投票的方法选择频率最高的类别'''
def classVote(classlist):#classlist是所有类别的列表
classCount={}
for vote in classlist:
if vote not in classCount.keys():
classCount[vote]=0
classCount[vote]+=1 #利用classCount统计类别的次数
sortedclassCount=sorted(classCount.iteritems(),key=operator.itemgetter(1),reverse=True)
return sortedclassCount[0][0]
print classVote(labels)
'''构建决策树'''
def createTree(dataSet,labels):
classList=[example[-1] for example in dataSet]
#递归停止条件1:当dataset中所有样本的类别是同一个类别时,即都是classList[0]这个类别,则返回这个类别
if classList.count(classList[0])==len(classList):return classList[0]
#递归停止条件2:每次生成的dataset的样本都至少有有一列,
# 即至少有类别这一列,如果某次dataSet[0]这个样本的长度为1,
#那么说明这个数据样本只剩下类别这一列了,也就是所有的属性都判断过了
#仍然不能得到只有一个类别的数据集,则进行投票
if len(dataSet[0])==1:return classVote(labels)
bestfeat=chooseBestFeatureToSplit(dataSet) #寻找最优划分属性的序号
bestfeatLabel=labels[bestfeat] #得到最优属性的名称
myTree={bestfeatLabel:{}} #利用该最优属性作为字典的键,即第一层
del labels[bestfeat] #删除这个属性
featValues=[example[bestfeat] for example in dataSet] #得到该属性的所有属性值
uniqueVals=set(featValues)
#对于每个属性值进行递归生成树
for value in uniqueVals:
subLabels=labels[:]
subdataSet=splitDataSet(dataSet,bestfeat,value) #根据最优属性的属性值划分数据集,进行递归
myTree[bestfeatLabel][value]=createTree(subdataSet,subLabels)
return myTree
#print myTree
tree=createTree(myData,labels)
print tree
str=tree.keys() #返回的是列表
print str
#print tree[str]
'''{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}'''
def classify(inputTree,featLabels,testVec):
firstStr=inputTree.keys()[0] #第一个划分属性名称(字典的第一个键dict.keys()返回的是列表)
secondDict=inputTree[firstStr]
featindex=0
try: #如果不加这个if判断语句,则会出现valueError
featindex=featLabels.index(firstStr) #看这一属性是第几个属性,求其索引
except:
print "error"
#classLabel="yes"
#global classLabel
for key in secondDict.keys():
if testVec[featindex]==key: #看测试样本的第featindex的值是否与key匹配
if type(secondDict[key]).__name__=="dict": #看它是否是字典,则表示可继续匹配,不是的话,则说话后面直接是类别了
result=classify(secondDict[key],featLabels,testVec)
else:
result=secondDict[key]
return result
#print "测试"
#print classify(tree,labels,[1,6])
#由于决策树的构造非常耗时,所以在每次分类的时候应该时候使用
# 已经在磁盘上存储好的决策树来判断这样在处理大型数据集时,比较省时间
#为了解决这个问题,需要使用python模块pickle序列化对象
#序列化对象可以在磁盘上保存对象,并且在需要的时候读取出来,任何对象都可以执行序列化操作
#字典也不例外
import pickle
def storeTree(inputTree,filename): #存储
fw=open(filename,"w")
pickle.dump(inputTree,fw)
fw.close()
def getTree(filename):#提取
fr=open(filename,"r")
return pickle.load(fr)
#storeTree(tree,"tree.txt")
#print getTree("tree.txt")
fileread=open("lenses.txt","r")
lensesdata=[line.strip().split("\t") for line in fileread.readlines()]
print lensesdata
lenselabels=["age","prescript","astigmatic","tearRate"]
lensesTree=createTree(lensesdata,lenselabels)
print lensesTree
#storeTree(lensesTree,"lensesTree.txt")
#print getTree("lensesTree.txt")
#print classify(lensesTree,lenselabels, ['presbyopic', 'myope', 'no', 'reduced'])