使用Python 实现决策树算法,主要是对ID3算法的实现
决策树算法的理论知识,这里不做介绍,网上很多,可以看该博客的,写的很详细,也有例子
虽然也有很多人也使用Python 写了,但是这里我自己写了,是加深自己对决策树的理解,这里做个记录吧。
# -*- coding: utf-8 -*-
"""
Created on Tue May 5 10:02:36 2020
@author: Administrator
"""
# =============================================================================
# 构建决策树的整体思路如下:
#第一步:计算信息熵;
#第二步:根据特征来划分数据集,并计算每个特征对应的信息熵,选择信息熵最优的,重新划分数据集;
#第三步:重复第二步的内容,直到判断条件:①每个分支下的所有实例都具有相同的分类;②程序遍历完所有划分数据集的属性。
# 若满足条件①,那说明可以很好的分类,若不能满足条件①,即数据集已经处理了所有的属性,但类标签依然不是唯一,
# 通常采用表决的方法决定改叶子节点的分类,即输出分类数量最多的类别
# =============================================================================
from math import log
import operator
import pickle
def majorityCnt(classList):#表决的方法决定改叶子节点的分类
classCount={}
for vote in classList:
if vote not in classCount.keys(): classCount[vote]=0
classCount+=1
sorteClassCount=sorted(classCount.iteritems(),key=operator.itemgetter(1),reversed=True)
return sorteClassCount[0][0]
def calcShannonEnt(dataSet):#计算信息熵
num=len(dataSet)
labelCounts={}
for featVec in dataSet:
currentLabel=featVec[-1] #取类别,即最后一列
if currentLabel not in labelCounts.keys(): #判断其他的类别是否在类别列表中
labelCounts[currentLabel]=0
labelCounts[currentLabel] +=1 #统计每个类别的个数
shannonEnt=0.0
for key in labelCounts:
prob=float(labelCounts[key])/num
shannonEnt -= prob*log(prob,2)
return shannonEnt
def splitDataSet(dataSet,axis,value): #重新划分数据集
retDataSet=[]
for featVec in dataSet:
if featVec[axis]==value:
reducedFeatVec=featVec[:axis]
reducedFeatVec.extend(featVec[axis+1:])
retDataSet.append(reducedFeatVec)
return retDataSet
def chooseBestFeatureToSplit(dataSet):
numFeatures=len(dataSet[0])-1 #基本信息熵在按照最后一个特征先计算好了,因此这里减去1
baseEntropy=calcShannonEnt(dataSet)#基本信息熵
bestInfoGain=0.0
bestFeature=-1
for i in range(numFeatures):
featList=[example[i] for example in dataSet] #列出第i列的每一行特征
uniqueVals=set(featList) #set函数可以去除重复项
newEntropy=0.0
for value in uniqueVals:
subDataSet=splitDataSet(dataSet,i,value) #基于第i特征和value特征重新划分数据集
prob=len(subDataSet)/float(len(dataSet)) #这个for循环是计算第i特征划分所得到的信息熵
newEntropy +=prob*calcShannonEnt(subDataSet)
infoGain=baseEntropy-newEntropy #
if(infoGain>bestInfoGain):#得到最好的信息增益
bestInfoGain=infoGain
bestFeature=i
return bestFeature
def createTree(dataSet,labels):
classList=[example[-1] for example in dataSet]
if classList.count(classList[0]) == len(classList):#类别完全相同则停止继续划分
return classList[0]
if len(dataSet)==1: #遍历完所有特征时返回出现次数最多的
return majorityCnt(classList)
bestFeat=chooseBestFeatureToSplit(dataSet)
bestFeatLabel=labels[bestFeat]
myTree={bestFeatLabel:{}}
del(labels[bestFeat]) #删除最优特征对于的标签
featValues=[example[bestFeat] for example in dataSet]
uniqueVals=set(featValues)
for value in uniqueVals:
subLabels=labels[:]
myTree[bestFeatLabel][value] =createTree(splitDataSet(dataSet,bestFeat,value),subLabels)
return myTree
def createDataSet():
dataSet=[[1,1,'yes'],[1,1,'yes'],[1,0,'no'],[0,1,'no'],[0,1,'no']]
labels=['no surfacing','flippers']
return dataSet,labels
#预测使用
def classify(inputTree,featLabels,testVec):
firstStr=list(inputTree.keys())[0]
secondDict=inputTree[firstStr]
featIndex=featLabels.index(firstStr)
for key in secondDict.keys():
if testVec[featIndex] ==key:
if type(secondDict[key]).__name__=='dict':
classLabel=classify(secondDict[key],featLabels,testVec)
else:
classLabel=secondDict[key]
return classLabel
#决策树存储
def storeTree(inputTree,filename):
fw=open(filename,'wb')
pickle.dump(inputTree,fw)
fw.close()
#还原树结构
def grabTree(filename):
fr=open(filename,'rb')
return pickle.load(fr)
if __name__ == "__main__":
dataSet,label=createDataSet()
label1=label.copy() #因为createTree()使用了del函数,因此这里将标签复制一份
myTree=createTree(dataSet,label1)
storeTree(myTree,'classfiyTree.txt')
tree=grabTree('classfiyTree.txt')
print(tree)
print(classify(myTree,label,[1,1]))