决策树ID3实例(python2.7代码)

@决策树原理实例(python代码实现)

决策树ID3实例(python2.7代码)

实现ID3决策树:根据输入的决策数据构建决策树,并最终将其输出为一张图片。
在这里插入图片描述

数据库

使用数据库MySQL读取出数据:

def execute():
    """Load the training data from the MySQL ``ball`` table.

    Runs a fixed ``SELECT`` over the columns Outlook, Temperature, Humidity,
    Windy and Play, converts every value to ``str``, then prints the
    attribute names and pretty-prints the rows.

    Returns:
        tuple: ``(dataSet, labels, label1, label2)``.  ``dataSet`` is a list
        of rows, each a list of strings with the ``Play`` class value last.
        The three label lists are independent copies of the attribute names
        (``Play`` excluded), in the same order as the row columns.  An empty
        table yields an empty ``dataSet``.
    """
    from pprint import pprint

    # The column order is fixed here instead of being taken from the row
    # dicts, whose key order is not guaranteed on Python 2.7.  The class
    # column ('Play') is kept last.
    columns = ['Outlook', 'Temperature', 'Humidity', 'Windy', 'Play']

    # NOTE(review): the connection settings, including the password, are
    # hard-coded; consider loading them from configuration instead.
    connection = pymysql.connect(host='172.24.28.255',
                                 port=3308,
                                 user='LongQin',
                                 password='password',
                                 db='baseball',
                                 charset='utf8mb4',
                                 cursorclass=pymysql.cursors.DictCursor)
    try:
        cursor = connection.cursor()
        try:
            cursor.execute("select Outlook,Temperature,Humidity,Windy,Play from ball")
            result = cursor.fetchall()
        finally:
            cursor.close()
    finally:
        # Always release the connection, even when the query fails.
        connection.close()

    # DictCursor returns one dict per row; read the values by column name so
    # every row has the same, known layout.
    dataSet = [[str(row[name]) for name in columns] for row in result]
    label = columns[:-1]

    print(label)
    pprint(dataSet)

    # Three separate lists so that a caller consuming one (createTree deletes
    # entries from the list it is given) does not affect the others.
    labels = copy.copy(label)
    label1 = copy.copy(label)
    label2 = copy.copy(label)
    return dataSet, labels, label1, label2

在这里插入图片描述
从数据库读取到的数据如上图所示。

pygraphviz

这里使用 pygraphviz 生成决策树图片(即文章最上方的那张图),其下载和使用教程可参考下面的链接:
https://blog.csdn.net/qq_35603331/article/details/81591949

整体代码

Python2.7可用
注意:数据库和 pygraphviz 都要先装好,否则就得舍弃相应部分的代码。另外,生成树的过程中 label 数组会被逐步删减,所以如果还要再生成一棵树,需要另外新建一个数组传入,代码中的 label1 就是这个作用。

# -*- coding: utf-8 -*
#!/usr/bin/env python
import math 
import operator
import xlrd
import pygraphviz as pyg
import pymysql.cursors
import copy


def createDataSet(path='C:/Users/Administrator/Desktop/20.xlsx', sheet='Sheet1'):
    """Build the data set from an Excel worksheet.

    Every row after the first becomes one list of strings.  A cell equal to
    0 becomes ``'False'``, a cell equal to 1 becomes ``'True'``, text cells
    are kept as native strings and any other value is converted with
    ``str``.  Each converted row is printed as it is read.

    Args:
        path: Workbook to open.  Defaults to the original hard-coded file.
        sheet: Name of the worksheet to read.  Defaults to ``'Sheet1'``.

    Returns:
        tuple: ``(dataSet, labels, label1, label2)`` where the three label
        lists are independent copies of the four attribute names.
    """
    # NOTE(review): xlrd 2.x only reads .xls workbooks; an .xlsx path such as
    # the default presumably needs an older xlrd release -- confirm.
    wb = xlrd.open_workbook(filename=path)
    ws = wb.sheet_by_name(sheet)  # baseball data
    dataSet = []
    for r in range(1, ws.nrows):  # row 0 is skipped (presumably the header)
        col = []
        for c in range(ws.ncols):
            value = ws.cell(r, c).value  # look the cell up only once
            if value == 0:
                col.append('False')
            elif value == 1:
                col.append('True')
            elif isinstance(value, str):
                col.append(value)
            elif hasattr(value, 'encode'):
                # Python 2 ``unicode`` text: store it as a UTF-8 byte string.
                col.append(value.encode('utf-8'))
            else:
                # Other numbers etc. have no ``encode``; fall back to str().
                col.append(str(value))
        print(col)
        dataSet.append(col)
    names = ['Outlook', 'Temperature', 'Humidity', 'Windy']  # the four attributes
    # Separate list objects, so consuming one does not affect the others.
    labels = list(names)
    label1 = list(names)
    label2 = list(names)
    return dataSet, labels, label1, label2

def execute():
    """Load the training data from the MySQL ``ball`` table.

    Runs a fixed ``SELECT`` over the columns Outlook, Temperature, Humidity,
    Windy and Play, converts every value to ``str``, then prints the
    attribute names and pretty-prints the rows.

    Returns:
        tuple: ``(dataSet, labels, label1, label2)``.  ``dataSet`` is a list
        of rows, each a list of strings with the ``Play`` class value last.
        The three label lists are independent copies of the attribute names
        (``Play`` excluded), in the same order as the row columns.  An empty
        table yields an empty ``dataSet``.
    """
    from pprint import pprint

    # The column order is fixed here instead of being taken from the row
    # dicts, whose key order is not guaranteed on Python 2.7.  The class
    # column ('Play') is kept last.
    columns = ['Outlook', 'Temperature', 'Humidity', 'Windy', 'Play']

    # NOTE(review): the connection settings, including the password, are
    # hard-coded; consider loading them from configuration instead.
    connection = pymysql.connect(host='172.24.28.255',
                                 port=3308,
                                 user='LongQin',
                                 password='password',
                                 db='baseball',
                                 charset='utf8mb4',
                                 cursorclass=pymysql.cursors.DictCursor)
    try:
        cursor = connection.cursor()
        try:
            cursor.execute("select Outlook,Temperature,Humidity,Windy,Play from ball")
            result = cursor.fetchall()
        finally:
            cursor.close()
    finally:
        # Always release the connection, even when the query fails.
        connection.close()

    # DictCursor returns one dict per row; read the values by column name so
    # every row has the same, known layout.
    dataSet = [[str(row[name]) for name in columns] for row in result]
    label = columns[:-1]

    print(label)
    pprint(dataSet)

    # Three separate lists so that a caller consuming one (createTree deletes
    # entries from the list it is given) does not affect the others.
    labels = copy.copy(label)
    label1 = copy.copy(label)
    label2 = copy.copy(label)
    return dataSet, labels, label1, label2

def reckonEsum(dataSet):
    """Return the base-2 entropy of the class column of *dataSet*.

    The class of a row is its last element.  An empty data set gives 0.0.
    """
    total = len(dataSet)  # number of rows (transactions)

    # Count how many rows fall into each class.
    tally = {}
    for row in dataSet:
        cls = row[-1]
        tally[cls] = tally.get(cls, 0) + 1

    entropy = 0.0
    for count in tally.values():
        share = float(count) / total
        entropy -= share * math.log(share, 2)
    return entropy

def splitData(dataSet, i, value):
    """Return the rows whose column *i* equals *value*, with that column removed.

    The input rows are not modified; each returned row is a new list.
    """
    matching = (row for row in dataSet if row[i] == value)
    subset = []
    for row in matching:
        reduced = row[:i]
        reduced.extend(row[i + 1:])
        subset.append(reduced)
    return subset

def chooseBest(dataSet):
    """Return the column index of the attribute with the highest information gain.

    The last column of every row is the class and is never a candidate.
    Ties are resolved in favour of the lowest index, so when no attribute
    has a positive gain the first attribute is returned.

    Returns:
        int: index of the chosen attribute column, or -1 when *dataSet* is
        empty or its rows contain nothing but the class column.
    """
    if not dataSet:
        return -1
    num = len(dataSet[0]) - 1  # number of attributes, class column excluded
    baseny = reckonEsum(dataSet)  # entropy before any split
    total = float(len(dataSet))
    bestGain = None
    bestattr = -1  # index of the best attribute seen so far
    for i in range(num):
        values = set(example[i] for example in dataSet)  # distinct values in column i
        conditional = 0.0
        for v in values:
            subset = splitData(dataSet, i, v)
            # Entropy of each partition, weighted by its share of the rows.
            conditional += len(subset) / total * reckonEsum(subset)
        infoGain = baseny - conditional  # information gain of attribute i
        if bestGain is None or infoGain > bestGain:
            bestGain = infoGain
            # Remember which attribute produced the gain; without this the
            # function could only ever return its initial -1.
            bestattr = i
    return bestattr
    
def drawTree(first,two):
    """Add an edge from node *first* to node *two* on the module-level graph ``A``."""
    A.add_edge(first,two)
    

def _render_graph(image_path):
    """Write the module-level graph ``A`` to ``fooOld.dot`` and draw it to *image_path*."""
    A.graph_attr['epsilon'] = '0.01'
    A.write('fooOld.dot')
    A.layout('dot')  # layout with dot
    A.draw(image_path)  # write the image file


def endtree():
    """Render the graph to ``b.png`` (the plain tree)."""
    _render_graph('b.png')


def endtree_c():
    """Render the graph to ``c.png`` (the tree with colour)."""
    _render_graph('c.png')
    

def createTree(dataSet, labels):
    """Recursively build the decision tree as nested dictionaries.

    Args:
        dataSet: list of rows; the last element of every row is the class.
        labels: attribute names matching the non-class columns.  The entry
            of the attribute chosen at this level is deleted from this list
            in place, so pass a copy if the list is needed afterwards.

    Returns:
        Either a class value (leaf) or ``{attribute: {value: subtree}}``.
    """
    classes = [example[-1] for example in dataSet]  # the class column
    if classes.count(classes[0]) == len(classes):
        # Every row has the same class: this branch is a leaf.
        return classes[0]
    if len(dataSet[0]) == 1:
        # No attribute is left to split on but the classes still differ:
        # answer with the most frequent class.  Sorting makes ties
        # deterministic.
        return max(sorted(set(classes)), key=classes.count)
    besti = chooseBest(dataSet)  # pick the best attribute
    bestattrLabel = labels[besti]
    cidianTree = {bestattrLabel: {}}  # the tree is stored as nested dicts
    del labels[besti]  # the chosen attribute is used up
    for value in set(example[besti] for example in dataSet):
        # Each branch gets its own copy of the remaining attribute names.
        subLabels = labels[:]
        cidianTree[bestattrLabel][value] = createTree(
            splitData(dataSet, besti, value), subLabels)
    return cidianTree

     
def travel(adict):
    """Walk the nested-dict tree and register its edges through ``drawTree``.

    For every key of *adict*: when its value is a dict, an edge is drawn
    from the key to each key of that dict; otherwise an edge is drawn from
    the key to the value itself.  The value is then visited recursively.
    Anything that is not a dict is ignored.
    """
    if not isinstance(adict, dict):
        return
    for parent, child in adict.items():
        if isinstance(child, dict):
            for name in list(child):
                drawTree(parent, name)
        else:
            drawTree(parent, child)
        travel(child)
            
def caculate(adict, chosen, zidi):
    """Classify the sample *chosen* by following it through the tree *adict*.

    Args:
        adict: tree in the ``{attribute: {value: subtree_or_leaf}}`` form.
        chosen: mapping of attribute name to the sample's value.
        zidi: output dict; every ``attribute: value`` pair followed on the
            way down is recorded in it.

    Returns:
        The leaf reached (any non-dict value), or ``None`` when the sample
        has no matching branch.
    """
    if not isinstance(adict, dict):
        return None
    for attribute, branches in adict.items():
        if not isinstance(branches, dict):
            continue
        if attribute not in chosen:
            continue
        picked = chosen[attribute]
        if picked not in branches:
            continue
        zidi[attribute] = picked
        subtree = branches[picked]
        if not isinstance(subtree, dict):
            return subtree  # reached a leaf
        # Hand the decision found deeper in the tree back to the caller;
        # dropping it would make every multi-level decision come out as None.
        result = caculate(subtree, chosen, zidi)
        if result is not None:
            return result
    return None

def biaohong(chosen, result):
    """Colour the nodes of a decision path red on the module-level graph ``A``.

    Every key and every value of *chosen* is added as a red node, followed
    by the node named *result*.  The string ``'None'`` is treated as ``'No'``.
    """
    for attribute, value in chosen.items():
        for name in (attribute, value):
            A.add_node(name, color='red')
    leaf = 'No' if result == 'None' else result
    A.add_node(leaf, color='red')
            
    
if __name__=='__main__':
    # Module-level graph that drawTree, biaohong and endtree_c operate on.
    A = pyg.AGraph(directed=True, strict=True)
    # Sample to classify; the trailing comment is the expected class.
    chosen = {'Outlook':'sunny','Temperature':'hot','Humidity':'high','Windy':'False'}#'No'
    #chosen = {'Outlook':'overcast','Temperature':'cool','Humidity':'normal','Windy':'True'}#'Yes'
    #chosen = {'Outlook':'rain','Temperature':'mild','Humidity':'high', 'Windy':'True'}#'No'
    dataSet, labels, label1, label2 = execute()  # load the sample data
    print(label1)
    zidi = dict()  # filled by caculate with the attribute/value path taken
    # Build the tree once and reuse it for drawing and for classification;
    # createTree deletes entries from the label list it is given.
    tree = createTree(dataSet, label1)
    travel(tree)
    result = caculate(tree, chosen, zidi)
    result = str(result)
    biaohong(zidi, result)
    endtree_c()

参考资料

这是我第一次写博文,文中代码引用自下面这篇博文,不确定引用格式是否正确:
https://blog.csdn.net/csqazwsxedc/article/details/65697652

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值