决策树原理实例（Python 代码实现）
决策树ID3实例(python2.7代码)
实现 ID3 决策树，最终能根据输入的决策数据生成决策树图片。
数据库
从 MySQL 数据库中读取数据：
def execute():
    """Load the baseball training set from MySQL.

    Runs ``select Outlook,Temperature,Humidity,Windy,Play from ball`` and
    converts every cell to a string. The ``Play`` column (the class label)
    is placed last in every row, because the ID3 helpers read the class
    from ``row[-1]``.

    Returns:
        (dataSet, labels, label1, label2): ``dataSet`` is a list of rows
        (lists of str); the three label lists are independent copies of the
        feature-column names, in row order, with ``Play`` removed.
    """
    from pprint import pprint

    def to_text(value):
        # Python 2 drivers hand back ``unicode``: encode it to a UTF-8 byte
        # string so it compares equal to the plain ``str`` literals used
        # elsewhere. Python 3 ``str`` and numbers go through ``str()``.
        if not isinstance(value, str) and hasattr(value, 'encode'):
            return value.encode('utf-8')
        return str(value)

    # NOTE(review): host, user and password are hard-coded; consider moving
    # them to configuration.
    connection = pymysql.connect(host='172.24.28.255',
                                 port=3308,
                                 user='LongQin',
                                 password='password',
                                 db='baseball',
                                 charset='utf8mb4',
                                 cursorclass=pymysql.cursors.DictCursor)
    try:
        cursor = connection.cursor()
        cursor.execute("select Outlook,Temperature,Humidity,Windy,Play from ball")
        result = cursor.fetchall()
        # Column names in SELECT order. Taking them from the cursor instead
        # of row.keys() avoids depending on dict ordering (arbitrary on
        # Python 2) and also works when the query returns no rows.
        columns = [desc[0] for desc in cursor.description]
    finally:
        connection.close()  # previously the connection was never closed

    feature_cols = [name for name in columns if name != 'Play']
    row_order = feature_cols + ['Play']  # class label must be the last column
    dataSet = [[to_text(row[name]) for name in row_order] for row in result]
    label = [to_text(name) for name in feature_cols]

    print(label)
    pprint(dataSet)
    # Three independent copies so each can be handed to a separate
    # createTree call.
    labels = copy.copy(label)
    label1 = copy.copy(label)
    label2 = copy.copy(label)  # feature names
    return dataSet, labels, label1, label2
数据库读取的数据;
pygraphviz
使用 pygraphviz 生成图片（即文章最上方的那张图），其下载与使用教程可参考：
https://blog.csdn.net/qq_35603331/article/details/81591949
整体代码
Python2.7可用
注意：需要先配置好 MySQL 数据库并安装 pygraphviz，否则需删去相应部分的代码。为避免 createTree 修改传入的 label 数组，若要再生成另一棵树，应传入一个新的数组副本，代码中的 label1 即为此用途。
# -*- coding: utf-8 -*-
#!/usr/bin/env python
import math
import operator
import xlrd
import pygraphviz as pyg
import pymysql.cursors
import copy
def createDataSet(path='C:/Users/Administrator/Desktop/20.xlsx'):
    """Load the baseball training set from the 'Sheet1' worksheet of an .xlsx file.

    The first worksheet row is treated as a header and skipped. Cells equal
    to 0 / 1 become 'False' / 'True'; every other cell is converted to text.

    Args:
        path: workbook location; defaults to the original hard-coded path.

    Returns:
        (dataSet, labels, label1, label2): rows as lists of str, plus three
        independent copies of ['Outlook', 'Temperature', 'Humidity', 'Windy'].
    """
    def to_text(value):
        if value == 0:
            return 'False'
        if value == 1:
            return 'True'
        if isinstance(value, str):
            # Python 3 text (or a Python 2 byte string): use as-is; encoding
            # it would give bytes that never equal the str literals used
            # in labels / chosen.
            return value
        if hasattr(value, 'encode'):
            # Python 2 ``unicode`` cell, encoded as the original code did.
            return value.encode('utf-8')
        # Numbers other than 0/1 previously crashed on ``.encode``.
        return str(value)

    wb = xlrd.open_workbook(filename=path)
    ws = wb.sheet_by_name('Sheet1')  # baseball data
    dataSet = []
    for r in range(1, ws.nrows):  # row 0 is the header
        row = [to_text(ws.cell_value(r, c)) for c in range(ws.ncols)]
        print(row)
        dataSet.append(row)
    features = ['Outlook', 'Temperature', 'Humidity', 'Windy']  # the four features
    return dataSet, list(features), list(features), list(features)
def execute():
    """Load the baseball training set from MySQL.

    Runs ``select Outlook,Temperature,Humidity,Windy,Play from ball`` and
    converts every cell to a string. The ``Play`` column (the class label)
    is placed last in every row, because the ID3 helpers read the class
    from ``row[-1]``.

    Returns:
        (dataSet, labels, label1, label2): ``dataSet`` is a list of rows
        (lists of str); the three label lists are independent copies of the
        feature-column names, in row order, with ``Play`` removed.
    """
    from pprint import pprint

    def to_text(value):
        # Python 2 drivers hand back ``unicode``: encode it to a UTF-8 byte
        # string so it compares equal to the plain ``str`` literals used
        # elsewhere. Python 3 ``str`` and numbers go through ``str()``.
        if not isinstance(value, str) and hasattr(value, 'encode'):
            return value.encode('utf-8')
        return str(value)

    # NOTE(review): host, user and password are hard-coded; consider moving
    # them to configuration.
    connection = pymysql.connect(host='172.24.28.255',
                                 port=3308,
                                 user='LongQin',
                                 password='password',
                                 db='baseball',
                                 charset='utf8mb4',
                                 cursorclass=pymysql.cursors.DictCursor)
    try:
        cursor = connection.cursor()
        cursor.execute("select Outlook,Temperature,Humidity,Windy,Play from ball")
        result = cursor.fetchall()
        # Column names in SELECT order. Taking them from the cursor instead
        # of row.keys() avoids depending on dict ordering (arbitrary on
        # Python 2) and also works when the query returns no rows.
        columns = [desc[0] for desc in cursor.description]
    finally:
        connection.close()  # previously the connection was never closed

    feature_cols = [name for name in columns if name != 'Play']
    row_order = feature_cols + ['Play']  # class label must be the last column
    dataSet = [[to_text(row[name]) for name in row_order] for row in result]
    label = [to_text(name) for name in feature_cols]

    print(label)
    pprint(dataSet)
    # Three independent copies so each can be handed to a separate
    # createTree call.
    labels = copy.copy(label)
    label1 = copy.copy(label)
    label2 = copy.copy(label)  # feature names
    return dataSet, labels, label1, label2
def reckonEsum(dataSet):
    """Return the base-2 Shannon entropy of the class labels in ``dataSet``.

    The class label of a row is its last element. An empty ``dataSet``
    yields 0.0.
    """
    total = len(dataSet)  # number of rows
    tally = {}
    for row in dataSet:
        cls = row[-1]
        tally[cls] = tally.get(cls, 0) + 1
    entropy = 0.0
    for cls, count in tally.items():
        share = float(count) / total
        entropy -= share * math.log(share, 2)
    return entropy
def splitData(dataSet, i, value):
    """Return the rows whose column ``i`` equals ``value``, with column ``i`` dropped.

    Every returned row is a new list; ``dataSet`` itself is left unchanged.
    """
    return [row[:i] + row[i + 1:] for row in dataSet if row[i] == value]
def chooseBest(dataSet):
    """Return the column index of the feature with the highest information gain.

    The last column of each row is the class label and is never chosen.
    Ties keep the lowest index. Returns -1 only when there is no feature
    column (every row holds just the class label).
    """
    num = len(dataSet[0]) - 1  # number of feature columns (class excluded)
    baseny = reckonEsum(dataSet)  # entropy before splitting
    # Start below any achievable gain (information gain is >= 0 in exact
    # arithmetic) so a feature is always selected when one exists. Starting
    # at 0 with a strict '>' returned -1 whenever no feature reduced entropy.
    bestGain = -1.0
    bestattr = -1  # column index of the best feature
    for i in range(num):
        values = set(example[i] for example in dataSet)  # distinct values of column i
        splitEntropy = 0.0
        for v in values:
            subset = splitData(dataSet, i, v)
            weight = float(len(subset)) / len(dataSet)
            splitEntropy += weight * reckonEsum(subset)  # entropy after splitting on i
        infoGain = baseny - splitEntropy  # information gain
        if infoGain > bestGain:
            bestGain = infoGain
            bestattr = i  # was never recorded before, so -1 was always returned
    return bestattr
def drawTree(first, two):
    """Add the directed edge ``first -> two`` to the module-level graph ``A``."""
    edge = (first, two)
    A.add_edge(*edge)
def endtree(png_path='b.png', dot_path='fooOld.dot'):
    """Lay out the module-level graph ``A`` with ``dot`` and render it.

    Writes the DOT source to ``dot_path`` and the rendered tree image to
    ``png_path``. The defaults are the original hard-coded file names, so
    existing ``endtree()`` calls behave as before.
    """
    # NOTE(review): Graphviz documents 'epsilon' as a neato termination
    # threshold; it likely has no effect with the 'dot' layout used below.
    A.graph_attr['epsilon'] = '0.01'
    A.write(dot_path)
    A.layout('dot')  # layout with dot
    A.draw(png_path)  # the plain (uncoloured) tree image
def endtree_c(png_path='c.png', dot_path='fooOld.dot'):
    """Lay out the module-level graph ``A`` with ``dot`` and render it.

    Same as ``endtree`` but defaults to 'c.png'; it is called after the
    nodes on the decision path have been coloured. The defaults are the
    original hard-coded file names, so existing ``endtree_c()`` calls
    behave as before.
    """
    # NOTE(review): Graphviz documents 'epsilon' as a neato termination
    # threshold; it likely has no effect with the 'dot' layout used below.
    A.graph_attr['epsilon'] = '0.01'
    A.write(dot_path)
    A.layout('dot')  # layout with dot
    A.draw(png_path)  # the tree image with the coloured path
def createTree(dataSet, labels):
    """Build an ID3 decision tree as nested dicts.

    Internal nodes have the form ``{feature_name: {feature_value: subtree}}``;
    leaves are class labels (the last column of each row).

    Args:
        dataSet: non-empty list of rows; the last element of each row is the
            class label.
        labels: feature names for the non-class columns, in column order.
            The list is copied internally and is not modified.

    Returns:
        The tree (a dict), or a single class label when no split is made.
    """
    classList = [example[-1] for example in dataSet]  # class column values
    if classList.count(classList[0]) == len(classList):
        return classList[0]  # all rows agree: leaf

    def majority():
        # Most frequent class; ties go to the class seen first.
        counts = {}
        for c in classList:
            counts[c] = counts.get(c, 0) + 1
        return max(classList, key=counts.get)

    if len(dataSet[0]) == 1:
        # No feature columns left but the classes still disagree. The old
        # code indexed an empty labels list here and raised IndexError.
        return majority()
    besti = chooseBest(dataSet)  # best feature column
    if besti < 0:
        return majority()  # no usable split
    labels = labels[:]  # work on a copy so the caller's list stays intact
    bestattrLabel = labels[besti]
    cidianTree = {bestattrLabel: {}}
    del labels[besti]  # the chosen feature is consumed on this branch
    for value in set(example[besti] for example in dataSet):
        cidianTree[bestattrLabel][value] = createTree(
            splitData(dataSet, besti, value), labels)
    return cidianTree
def travel(adict):
    """Draw every parent -> child edge of a decision tree through ``drawTree``.

    For each ``key: value`` pair of a dict node, an edge is drawn from
    ``key`` to each key of ``value`` when ``value`` is a dict, otherwise
    from ``key`` to the leaf ``value``; the walk then recurses into
    ``value``. Non-dict input draws nothing.
    """
    if not isinstance(adict, dict):
        return
    for parent, child in adict.items():
        if not isinstance(child, dict):
            drawTree(parent, child)
        else:
            for branch in list(child.keys()):
                drawTree(parent, branch)
        travel(child)
def caculate(adict, chosen, zidi):
    """Follow ``chosen`` down a decision tree and return the predicted class.

    Args:
        adict: tree as built by ``createTree``: ``{feature: {value: subtree}}``
            with class labels as leaves.
        chosen: mapping feature name -> observed value for the sample.
        zidi: dict that receives each ``feature: value`` pair followed along
            the path (updated in place; the caller uses it for highlighting).

    Returns:
        The class label of the leaf reached, or None when ``adict`` is not a
        dict or the sample's value for a tested feature has no branch.
    """
    if not isinstance(adict, dict):
        return None
    for feature, branches in adict.items():
        if not isinstance(branches, dict) or feature not in chosen:
            continue
        value = chosen[feature]
        if value not in branches:
            continue
        zidi[feature] = value
        subtree = branches[value]
        # Any non-dict node is a leaf; the old code only recognised the
        # literal labels 'Yes' / 'No'.
        if not isinstance(subtree, dict):
            return subtree
        # Propagate the recursive result. It used to be discarded, so every
        # path longer than one test returned None.
        found = caculate(subtree, chosen, zidi)
        if found is not None:
            return found
    return None
def biaohong(chosen, result):
    """Colour red, in the module-level graph ``A``, the nodes of a decision path.

    Args:
        chosen: mapping feature -> value that was followed; both the feature
            node and the value node are coloured.
        result: predicted class as a string. The string 'None' is coloured
            as the 'No' node.
    """
    for feature in chosen:
        A.add_node(feature, color='red')
        A.add_node(chosen[feature], color='red')
    # NOTE(review): 'None' means no leaf was found; treating it as 'No' is a
    # guess rather than a prediction. Confirm this is intended.
    leaf = 'No' if result == 'None' else result
    A.add_node(leaf, color='red')
if __name__ == '__main__':
    # Graph used by drawTree / biaohong / endtree_c (module global).
    A = pyg.AGraph(directed=True, strict=True)
    # Sample to classify; the trailing comment is the class the author
    # expected for it.
    chosen = {'Outlook': 'sunny', 'Temperature': 'hot', 'Humidity': 'high', 'Windy': 'False'}  # 'No'
    # chosen = {'Outlook': 'overcast', 'Temperature': 'cool', 'Humidity': 'normal', 'Windy': 'True'}  # 'Yes'
    # chosen = {'Outlook': 'rain', 'Temperature': 'mild', 'Humidity': 'high', 'Windy': 'True'}  # 'No'
    dataSet, labels, label1, label2 = execute()  # load the training data
    print(label1)
    zidi = dict()
    # Build the tree once and reuse it for drawing and for prediction; the
    # original built an identical second tree from another label copy.
    tree = createTree(dataSet, label1)
    travel(tree)
    result = str(caculate(tree, chosen, zidi))
    biaohong(zidi, result)
    endtree_c()
参考资料
本文代码参考了下面这篇博文（第一次写博文，格式如有不当还请见谅）：
https://blog.csdn.net/csqazwsxedc/article/details/65697652