Implementing the ID3 Decision Tree Algorithm in Python

The complete example is as follows:

# -*- coding:utf-8 -*-

import operator
from math import log

import pandas as pd

# Compute the Shannon entropy of a data set; the class label is
# assumed to be the last column of every sample.
def calcShannonEnt(dataSet):
    numEntries=len(dataSet)
    labelCounts={}
    # count the occurrences of every possible class
    for featVec in dataSet:
        currentLabel=featVec[-1]
        if currentLabel not in labelCounts:
            labelCounts[currentLabel]=0
        labelCounts[currentLabel]+=1
    shannonEnt=0.0
    # Shannon entropy with a base-2 logarithm
    for key in labelCounts:
        prob=float(labelCounts[key])/numEntries
        shannonEnt-=prob*log(prob,2)
    return shannonEnt
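This computes H(D) = -Σ_k p_k·log2(p_k) over the class proportions p_k. As a quick sanity check on a made-up five-sample set (not from the watermelon data) split 3/2 between two classes:

toy=[[1,'yes'],[1,'yes'],[0,'yes'],[0,'no'],[0,'no']]
print(calcShannonEnt(toy))   # 0.9709... = -(3/5)*log2(3/5)-(2/5)*log2(2/5)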

# Split the data set on a discrete feature: keep every sample whose
# feature at index `axis` equals `value`, with that column removed.
def splitDataSet(dataSet,axis,value):
    retDataSet=[]
    for featVec in dataSet:
        if featVec[axis]==value:
            reducedFeatVec=featVec[:axis]
            reducedFeatVec.extend(featVec[axis+1:])
            retDataSet.append(reducedFeatVec)
    return retDataSet
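For example, again with toy values of my own:

toy=[['a',1,'yes'],['a',0,'no'],['b',1,'yes']]
print(splitDataSet(toy,0,'a'))   # [[1, 'yes'], [0, 'no']]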

# Split the data set on a continuous feature. `direction` selects the
# side of the threshold to keep: 0 returns the samples whose value is
# greater than `value`, 1 returns those less than or equal to it.
def splitContinuousDataSet(dataSet,axis,value,direction):
    retDataSet=[]
    for featVec in dataSet:
        if direction==0:
            if featVec[axis]>value:
                reducedFeatVec=featVec[:axis]
                reducedFeatVec.extend(featVec[axis+1:])
                retDataSet.append(reducedFeatVec)
        else:
            if featVec[axis]<=value:
                reducedFeatVec=featVec[:axis]
                reducedFeatVec.extend(featVec[axis+1:])
                retDataSet.append(reducedFeatVec)
    return retDataSet
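A small illustration of the `direction` flag, on made-up values:

toy=[[0.2,'no'],[0.5,'yes'],[0.8,'yes']]
print(splitContinuousDataSet(toy,0,0.45,0))   # > 0.45  -> [['yes'], ['yes']]
print(splitContinuousDataSet(toy,0,0.45,1))   # <= 0.45 -> [['no']]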

# Choose the best feature to split on by information gain (ID3),
# handling both continuous and discrete features.
def chooseBestFeatureToSplit(dataSet,labels):
    numFeatures=len(dataSet[0])-1
    baseEntropy=calcShannonEnt(dataSet)
    bestInfoGain=0.0
    bestFeature=-1
    bestSplitDict={}
    for i in range(numFeatures):
        featList=[example[i] for example in dataSet]
        # continuous feature: generate n-1 candidate split points
        if type(featList[0]).__name__=='float' or type(featList[0]).__name__=='int':
            sortfeatList=sorted(featList)
            splitList=[]
            for j in range(len(sortfeatList)-1):
                splitList.append((sortfeatList[j]+sortfeatList[j+1])/2.0)
            bestSplitEntropy=10000
            slen=len(splitList)
            # compute the entropy of splitting at the j-th candidate
            # point, and record the best split point
            for j in range(slen):
                value=splitList[j]
                newEntropy=0.0
                subDataSet0=splitContinuousDataSet(dataSet,i,value,0)
                subDataSet1=splitContinuousDataSet(dataSet,i,value,1)
                prob0=len(subDataSet0)/float(len(dataSet))
                newEntropy+=prob0*calcShannonEnt(subDataSet0)
                prob1=len(subDataSet1)/float(len(dataSet))
                newEntropy+=prob1*calcShannonEnt(subDataSet1)
                if newEntropy<bestSplitEntropy:
                    bestSplitEntropy=newEntropy
                    bestSplit=j
            # record the best split point of the current feature
            bestSplitDict[labels[i]]=splitList[bestSplit]
            infoGain=baseEntropy-bestSplitEntropy
        # discrete feature
        else:
            uniqueVals=set(featList)
            newEntropy=0.0
            # weighted entropy over every value of this feature
            for value in uniqueVals:
                subDataSet=splitDataSet(dataSet,i,value)
                prob=len(subDataSet)/float(len(dataSet))
                newEntropy+=prob*calcShannonEnt(subDataSet)
            infoGain=baseEntropy-newEntropy
        if infoGain>bestInfoGain:
            bestInfoGain=infoGain
            bestFeature=i
    # If the best feature for the current node is continuous, binarize it
    # in place at the recorded split point, i.e. each value becomes
    # "is it <= bestSplitValue" (1) or not (0).
    if type(dataSet[0][bestFeature]).__name__=='float' or type(dataSet[0][bestFeature]).__name__=='int':
        bestSplitValue=bestSplitDict[labels[bestFeature]]
        labels[bestFeature]=labels[bestFeature]+'<='+str(bestSplitValue)
        for i in range(len(dataSet)):
            if dataSet[i][bestFeature]<=bestSplitValue:
                dataSet[i][bestFeature]=1
            else:
                dataSet[i][bestFeature]=0
    return bestFeature
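Note that this function mutates its arguments when the winning feature is continuous: the feature column is binarized in place and the label is renamed to record the threshold. A toy run (values made up for illustration):

toy=[[0.2,'a','no'],[0.5,'a','yes'],[0.8,'b','yes']]
toyLabels=['density','texture']
print(chooseBestFeatureToSplit(toy,toyLabels))   # 0
print(toyLabels)   # ['density<=0.35', 'texture']
print(toy)         # [[1, 'a', 'no'], [0, 'a', 'yes'], [0, 'b', 'yes']]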

# When no features remain but the samples at a node still disagree on
# the class, fall back to a majority vote.
def majorityCnt(classList):
    classCount={}
    for vote in classList:
        if vote not in classCount:
            classCount[vote]=0
        classCount[vote]+=1
    sortedClassCount=sorted(classCount.items(),key=operator.itemgetter(1),reverse=True)
    return sortedClassCount[0][0]
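A quick check that the vote returns the most frequent label rather than just the lexicographically largest one:

print(majorityCnt(['no','yes','no']))   # 'no'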

# Main routine: recursively build the decision tree. data_full and
# labels_full keep the complete data set, so that values of a discrete
# feature that happen to be absent from the current subset still get a
# branch (decided by majority vote).
def createTree(dataSet,labels,data_full,labels_full):
    classList=[example[-1] for example in dataSet]
    # all samples share one class: return a leaf
    if classList.count(classList[0])==len(classList):
        return classList[0]
    # no features left: vote
    if len(dataSet[0])==1:
        return majorityCnt(classList)
    bestFeat=chooseBestFeatureToSplit(dataSet,labels)
    bestFeatLabel=labels[bestFeat]
    myTree={bestFeatLabel:{}}
    featValues=[example[bestFeat] for example in dataSet]
    uniqueVals=set(featValues)
    if type(dataSet[0][bestFeat]).__name__=='str':
        currentlabel=labels_full.index(labels[bestFeat])
        featValuesFull=[example[currentlabel] for example in data_full]
        uniqueValsFull=set(featValuesFull)
    del(labels[bestFeat])
    # grow one subtree for every value of bestFeat
    for value in uniqueVals:
        subLabels=labels[:]
        if type(dataSet[0][bestFeat]).__name__=='str':
            uniqueValsFull.remove(value)
        myTree[bestFeatLabel][value]=createTree(splitDataSet(dataSet,bestFeat,value),subLabels,data_full,labels_full)
    # values never seen in this subset become majority-vote leaves
    if type(dataSet[0][bestFeat]).__name__=='str':
        for value in uniqueValsFull:
            myTree[bestFeatLabel][value]=majorityCnt(classList)
    return myTree
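The returned tree is nothing more than nested dictionaries: each internal node maps a feature label to a dict of value → subtree, and each leaf is a bare class label, e.g.:

{'texture': {'distinct': 1, 'blur': 0, 'little_blur': 0}}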

import matplotlib.pyplot as plt

decisionNode=dict(boxstyle="sawtooth",fc="0.8")
leafNode=dict(boxstyle="round4",fc="0.8")
arrow_args=dict(arrowstyle="<-")

# count the number of leaf nodes in the tree
def getNumLeafs(myTree):
    numLeafs=0
    firstSides=list(myTree.keys())
    firstStr=firstSides[0]
    secondDict=myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__=='dict':
            numLeafs+=getNumLeafs(secondDict[key])
        else:
            numLeafs+=1
    return numLeafs

# compute the maximum depth of the tree
def getTreeDepth(myTree):
    maxDepth=0
    firstSides=list(myTree.keys())
    firstStr=firstSides[0]
    secondDict=myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__=='dict':
            thisDepth=1+getTreeDepth(secondDict[key])
        else:
            thisDepth=1
        if thisDepth>maxDepth:
            maxDepth=thisDepth
    return maxDepth
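Both helpers can be tried on a hand-built tree dict:

tiny={'texture':{'distinct':{'touch':{'soft':1,'hard':0}},'blur':0}}
print(getNumLeafs(tiny))    # 3
print(getTreeDepth(tiny))   # 2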

# draw a node, with an arrow pointing to it from its parent
def plotNode(nodeTxt,centerPt,parentPt,nodeType):
    createPlot.ax1.annotate(nodeTxt,xy=parentPt,xycoords='axes fraction',
        xytext=centerPt,textcoords='axes fraction',va="center",ha="center",
        bbox=nodeType,arrowprops=arrow_args)

# draw the text on the arrow between parent and child
def plotMidText(cntrPt,parentPt,txtString):
    lens=len(txtString)
    xMid=(parentPt[0]+cntrPt[0])/2.0-lens*0.002
    yMid=(parentPt[1]+cntrPt[1])/2.0
    createPlot.ax1.text(xMid,yMid,txtString)

def plotTree(myTree,parentPt,nodeTxt):
    numLeafs=getNumLeafs(myTree)
    depth=getTreeDepth(myTree)
    firstSides=list(myTree.keys())
    firstStr=firstSides[0]
    cntrPt=(plotTree.x0ff+(1.0+float(numLeafs))/2.0/plotTree.totalW,plotTree.y0ff)
    plotMidText(cntrPt,parentPt,nodeTxt)
    plotNode(firstStr,cntrPt,parentPt,decisionNode)
    secondDict=myTree[firstStr]
    plotTree.y0ff=plotTree.y0ff-1.0/plotTree.totalD
    for key in secondDict.keys():
        if type(secondDict[key]).__name__=='dict':
            plotTree(secondDict[key],cntrPt,str(key))
        else:
            plotTree.x0ff=plotTree.x0ff+1.0/plotTree.totalW
            plotNode(secondDict[key],(plotTree.x0ff,plotTree.y0ff),cntrPt,leafNode)
            plotMidText((plotTree.x0ff,plotTree.y0ff),cntrPt,str(key))
    plotTree.y0ff=plotTree.y0ff+1.0/plotTree.totalD

def createPlot(inTree):
    fig=plt.figure(1,facecolor='white')
    fig.clf()
    axprops=dict(xticks=[],yticks=[])
    createPlot.ax1=plt.subplot(111,frameon=False,**axprops)
    # function attributes act as globals shared with plotTree
    plotTree.totalW=float(getNumLeafs(inTree))
    plotTree.totalD=float(getTreeDepth(inTree))
    plotTree.x0ff=-0.5/plotTree.totalW
    plotTree.y0ff=1.0
    plotTree(inTree,(0.5,1.0),'')
    plt.show()

df=pd.read_csv('watermelon_4_3.csv')
data=df.values[:,1:].tolist()
data_full=data[:]
labels=df.columns.values[1:-1].tolist()
labels_full=labels[:]
myTree=createTree(data,labels,data_full,labels_full)
print(myTree)
createPlot(myTree)

The final output:

{'texture': {'blur': 0,'little_blur': {'touch': {'soft_stick': 1,'hard_smooth': 0}},'distinct': {'density<=0.38149999999999995': {0: 1,1: 0}}}}
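The code above only builds, prints, and plots the tree. To classify a new sample with it, a minimal helper along these lines would work; `classify` is my own sketch, not part of the original post, and it assumes the test vector is encoded the same way as the training data (continuous features already binarized to 1/0 against the recorded threshold):

# Hypothetical helper (not from the original post): walk the nested dict
# down to a leaf, following the branch that matches the test vector.
def classify(inputTree,featLabels,testVec):
    firstStr=list(inputTree.keys())[0]
    secondDict=inputTree[firstStr]
    featIndex=featLabels.index(firstStr)
    branch=secondDict[testVec[featIndex]]
    if isinstance(branch,dict):
        return classify(branch,featLabels,testVec)
    return branch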

The rendered decision tree:

[figure: the tree above as plotted by createPlot; original image omitted]

References:

Machine Learning in Action

Machine Learning, by Zhou Zhihua

That concludes this implementation of the ID3 decision tree algorithm in Python; I hope it serves as a useful reference.
