python实现ID3
具体的决策树原理在此就不再赘述,可自行百度或者看我之前写的:https://blog.csdn.net/weixin_38273255/article/details/88752468
在这里主要列出我使用的代码,和一些学习时候的心得。
心得
代码在网上其实都有一大片一大片的,但是自己觉得不太符合自己想要的要求,所以就找了网上的代码,并做了一些修改。很多地方肯定会有毛病,还望见谅。
我想要的是能够对连续型数据做处理,但是ID3是对离散型数据做处理,网上的示例代码也是如此,之前在学习原理的时候看到说可以对连续数据做离散处理:
- 对连续数据排序
- 将排好的序列等分(让每个区间落入的数据量基本一致)
- 重新对区间赋值
- 离散化完成
可能这个也存在问题,但是就这么实现了。
代码
- 导入数据
- 对数据离散化
- 创建决策树
- 绘制决策树图像
- 输入测试集测试
下面就列出我的代码,代码主要是对iris数据集做出了处理,其他数据集可自行修改。
main.py
# -*- coding: utf-8 -*-
import creat as ct
import draw as dr
import yanzheng as yz
from sklearn import datasets
import numpy as np
import lisanhua as lsh
import random
# Load the iris data set (150 samples, 4 continuous features, 3 classes).
iris = datasets.load_iris()
all_data = iris.data[:, :]
all_target = iris.target[:]
labels = iris.feature_names[:]

# Constants.
n = 150             # total number of samples
m = int(n * 2 / 3)  # number of samples used to build the tree
q = 4               # feature dimensionality
l = 7               # number of discretization bins per feature

# Discretize the continuous features (ID3 only handles discrete values).
all_data, a = lsh.lsh(all_data, l)

# Append the class label as the last column of every sample.
all_data = all_data.tolist()
all_target = all_target.tolist()
for i in range(len(all_target)):
    all_data[i].append(all_target[i])

# Shuffle the samples so the train/validation split is random.
random.shuffle(all_data)

# Training subset used to build the tree.
cj_data = all_data[:m]

# Build the decision tree.
myTree = ct.createTree(cj_data, labels)

# Validation subset: split back into features and targets.
all_data = np.array(all_data)
yz_target = np.array(all_data[m:n, q:q + 1])
yz_data = np.array(all_data[m:n, :q])
yz_labels = np.array(iris.feature_names[:])

# Measure accuracy on the validation set.
yz_shu = yz.yanzheng(myTree, yz_data, yz_labels, yz_target)
yz_bfb = float(yz_shu) / (n - m)

# Report: the tree itself, the number correct, and the accuracy.
print(myTree)
print(yz_shu)
print(yz_bfb)
dr.createPlot(myTree)
lisanhua.py
import numpy as np
def lsh(data, num):
    """Discretize each column of *data* into *num* equal-frequency bins.

    For every feature column the sorted values are split into ``num``
    intervals containing roughly the same number of samples; each value is
    then replaced by the lower bound of the interval it falls into.

    Args:
        data: 2-D numpy float array of continuous features (modified in place).
        num:  number of bins per feature.

    Returns:
        (data, a): the discretized array, and per column the list of the
        ``num`` interval lower bounds.
    """
    a = []
    for i in range(len(data[0])):
        # BUG FIX: data[:, i] is a *view*; calling .sort() on it used to sort
        # the column of `data` itself, scrambling rows across features (and
        # relative to the targets) before discretization. Sort a copy instead.
        column = data[:, i].copy()
        column.sort()
        l = len(column)
        # Lower bound of each equal-frequency interval.
        b = [column[int(k * l / num)] for k in range(num)]
        for j in range(len(data)):
            # Values at or above the last bound collapse into the last bin.
            if data[j, i] >= b[-1]:
                data[j, i] = b[-1]
                continue
            for q in range(1, num):
                if b[q - 1] <= data[j, i] < b[q]:
                    data[j, i] = b[q - 1]
        a.append(b)
    return data, a
creat.py
import math
import operator
def calcShannonEnt(dataset):
    """Return the Shannon entropy of the class labels (last column) of *dataset*."""
    total = len(dataset)
    counts = {}
    for row in dataset:
        label = row[-1]
        counts[label] = counts.get(label, 0) + 1
    # H = -sum(p * log2(p)) over the class distribution.
    return -sum((c / float(total)) * math.log(c / float(total), 2)
                for c in counts.values())
def CreateDataSet():
    """Return the small toy data set and its feature labels (textbook example)."""
    samples = [
        [1, 1, 'yes'],
        [1, 1, 'yes'],
        [1, 0, 'no'],
        [0, 1, 'no'],
        [0, 1, 'no'],
    ]
    feature_names = ['no surfacing', 'flippers']
    return samples, feature_names
def splitDataSet(dataSet, axis, value):
    """Return the rows whose feature at *axis* equals *value*, with that column removed."""
    return [row[:axis] + row[axis + 1:] for row in dataSet if row[axis] == value]
def chooseBestFeatureToSplit(dataSet):
    """Return the index of the feature with the highest information gain (-1 if none gains)."""
    base_entropy = calcShannonEnt(dataSet)
    best_gain = 0.0
    best_feature = -1
    num_features = len(dataSet[0]) - 1  # last column is the class label
    for feature in range(num_features):
        values = {row[feature] for row in dataSet}
        # Conditional entropy after splitting on this feature.
        cond_entropy = 0.0
        for value in values:
            subset = splitDataSet(dataSet, feature, value)
            weight = len(subset) / float(len(dataSet))
            cond_entropy += weight * calcShannonEnt(subset)
        gain = base_entropy - cond_entropy
        if gain > best_gain:
            best_gain = gain
            best_feature = feature
    return best_feature
def majorityCnt(classList):
    """Return the most frequent class label in *classList* (majority vote).

    Fixes two bugs in the original: the count was overwritten with 1 instead
    of incremented, and dict.iteritems() does not exist on Python 3.
    """
    classCount = {}
    for vote in classList:
        # BUG FIX: was `classCount[vote] = 1`, which reset every count to 1.
        classCount[vote] = classCount.get(vote, 0) + 1
    # BUG FIX: .iteritems() is Python-2-only; .items() works on both.
    sortedClassCount = sorted(classCount.items(),
                              key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]
def createTree(dataSet, labels):
    """Recursively build an ID3 decision tree as nested dicts.

    NOTE: mutates *labels* (the chosen feature name is deleted at each level).
    Leaves are class labels; internal nodes are {feature_name: {value: subtree}}.
    """
    classList = [row[-1] for row in dataSet]
    # Pure node: every sample has the same class -> leaf.
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    # No features left to split on -> majority-vote leaf.
    if len(dataSet[0]) == 1:
        return majorityCnt(classList)
    best = chooseBestFeatureToSplit(dataSet)
    best_label = labels[best]
    tree = {best_label: {}}
    del labels[best]
    for value in set(row[best] for row in dataSet):
        # Each recursion gets its own copy so sibling branches see the same labels.
        tree[best_label][value] = createTree(splitDataSet(dataSet, best, value),
                                             labels[:])
    return tree
draw.py
# -*- coding: utf-8 -*-
import matplotlib.pyplot as plt
# Matplotlib annotate styles: decision (internal) nodes, leaf nodes, and the
# arrow drawn from parent to child.
decisionNode=dict(boxstyle="sawtooth",fc="0.8")
leafNode=dict(boxstyle="round4",fc="0.8")
arrow_args=dict(arrowstyle="<-")
# Count the leaf nodes of the tree.
def getNumLeafs(myTree):
    """Return the number of leaf nodes in a nested-dict decision tree."""
    numLeafs = 0
    # BUG FIX: myTree.keys()[0] raises TypeError on Python 3 (dict_keys is
    # not subscriptable); next(iter(...)) works on both 2 and 3.
    firstStr = next(iter(myTree))
    secondDict = myTree[firstStr]
    for key in secondDict:
        child = secondDict[key]
        if isinstance(child, dict):
            numLeafs += getNumLeafs(child)
        else:
            numLeafs += 1
    return numLeafs
# Compute the maximum depth of the tree.
def getTreeDepth(myTree):
    """Return the maximum depth (number of decision levels) of the tree."""
    maxDepth = 0
    # BUG FIX: myTree.keys()[0] raises TypeError on Python 3; use next(iter()).
    firstStr = next(iter(myTree))
    secondDict = myTree[firstStr]
    for key in secondDict:
        if isinstance(secondDict[key], dict):
            thisDepth = 1 + getTreeDepth(secondDict[key])
        else:
            thisDepth = 1
        if thisDepth > maxDepth:
            maxDepth = thisDepth
    return maxDepth
# Draw a single annotated tree node.
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    """Draw one node box at *centerPt* with an arrow from *parentPt*.

    Uses the shared axes createPlot.ax1, set up by createPlot().
    """
    createPlot.ax1.annotate(
        nodeTxt,
        xy=parentPt,
        xycoords='axes fraction',
        xytext=centerPt,
        textcoords='axes fraction',
        va="center",
        ha="center",
        bbox=nodeType,
        arrowprops=arrow_args,
    )
# Label the edge between a parent node and a child node.
def plotMidText(cntrPt, parentPt, txtString):
    """Write *txtString* near the midpoint of the cntrPt-parentPt edge."""
    # Shift left slightly in proportion to the text length so it stays centred.
    x_mid = (parentPt[0] + cntrPt[0]) / 2.0 - len(txtString) * 0.002
    y_mid = (parentPt[1] + cntrPt[1]) / 2.0
    createPlot.ax1.text(x_mid, y_mid, txtString)
def plotTree(myTree, parentPt, nodeTxt):
    """Recursively plot the subtree rooted at *myTree*.

    Relies on function attributes (totalW, totalD, x0ff, y0ff) initialised
    by createPlot() to place nodes in axes-fraction coordinates.
    """
    numLeafs = getNumLeafs(myTree)
    # BUG FIX: myTree.keys()[0] raises TypeError on Python 3; use next(iter()).
    firstStr = next(iter(myTree))
    # Centre this node above the span of its leaves.
    cntrPt = (plotTree.x0ff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW,
              plotTree.y0ff)
    plotMidText(cntrPt, parentPt, nodeTxt)
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
    # Descend one level before drawing children.
    plotTree.y0ff = plotTree.y0ff - 1.0 / plotTree.totalD
    for key in secondDict:
        if isinstance(secondDict[key], dict):
            plotTree(secondDict[key], cntrPt, str(key))
        else:
            plotTree.x0ff = plotTree.x0ff + 1.0 / plotTree.totalW
            plotNode(secondDict[key], (plotTree.x0ff, plotTree.y0ff), cntrPt, leafNode)
            plotMidText((plotTree.x0ff, plotTree.y0ff), cntrPt, str(key))
    # Restore the depth cursor for this level's siblings.
    plotTree.y0ff = plotTree.y0ff + 1.0 / plotTree.totalD
def createPlot(inTree):
    """Render the decision tree *inTree* with matplotlib and show the figure."""
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    axis_props = dict(xticks=[], yticks=[])
    # Shared axes used by plotNode/plotMidText via the function attribute.
    createPlot.ax1 = plt.subplot(111, frameon=False, **axis_props)
    # Global layout state consumed by plotTree.
    plotTree.totalW = float(getNumLeafs(inTree))
    plotTree.totalD = float(getTreeDepth(inTree))
    plotTree.x0ff = -0.5 / plotTree.totalW
    plotTree.y0ff = 1.0
    plotTree(inTree, (0.5, 1.0), '')
    plt.show()
yanzheng.py
import numpy as np
def one(myTree, data, labels):
    """Classify a single sample *data* (1-D feature array) with the tree.

    NOTE(review): this expects tree keys of the form 'feature<=threshold'
    with children keyed 1 (condition true) / 0 (false). That does NOT match
    the value-keyed tree produced by creat.createTree — confirm which tree
    format is actually intended.
    """
    # Leaf: integer class label.
    if isinstance(myTree, int):
        return myTree
    # BUG FIX: myTree.keys()[0] raises TypeError on Python 3; use next(iter()).
    key = next(iter(myTree))
    feature, threshold = key.split('<=')
    if data[labels == feature] <= float(threshold):
        return one(myTree[key][1], data, labels)
    else:
        return one(myTree[key][0], data, labels)
def getResult(myTree, data, labels):
    """Classify every sample row of *data*; return the list of predictions."""
    return [one(myTree, row, labels) for row in data]
def yanzheng(myTree, data, labels, target):
    """Return how many samples in *data* the tree classifies as *target* says."""
    predictions = getResult(myTree, data, labels)
    correct = 0
    for predicted, actual in zip(predictions, target):
        if predicted == actual:
            correct += 1
    return correct
结果
结果见:
https://blog.csdn.net/weixin_38273255/article/details/88981203