决策树——id3算法的python代码实现

使用python语言实现了决策树算法——id3算法,废话不多说,直接贴代码

import math
train={#定义训练集
    1:{'outlook':'sunny','temp':'hot','hum':'high','wind':'weak','play':'no'},
    2:{'outlook':'sunny','temp':'hot','hum':'high','wind':'strong','play':'no'},
    3:{'outlook':'overcast','temp':'hot','hum':'high','wind':'weak','play':'yes'},
    4:{'outlook':'rain','temp':'mild','hum':'high','wind':'weak','play':'yes'},
    5:{'outlook':'rain','temp':'cool','hum':'normal','wind':'weak','play':'yes'},
    6:{'outlook':'rain','temp':'cool','hum':'normal','wind':'strong','play':'no'},
    7:{'outlook':'overcast','temp':'cool','hum':'normal','wind':'strong','play':'yes'},
    8:{'outlook':'sunny','temp':'mild','hum':'high','wind':'weak','play':'no'},
    9:{'outlook':'sunny','temp':'cool','hum':'normal','wind':'weak','play':'yes'},
    10:{'outlook':'rain','temp':'mild','hum':'normal','wind':'weak','play':'yes'},
    11:{'outlook':'sunny','temp':'mild','hum':'normal','wind':'strong','play':'yes'},
    12:{'outlook':'overcast','temp':'mild','hum':'high','wind':'strong','play':'yes'},
    13:{'outlook':'overcast','temp':'hot','hum':'normal','wind':'weak','play':'yes'},
    14:{'outlook':'rain','temp':'mild','hum':'high','wind':'strong','play':'no'},
    }

def info(train):#定义信息量计算方法,传入训练集
    total,totalzheng,totalfu,info=0,0,0,0
    for key in train.keys():#计算训练样本中yes和no的个数
        total+=1
        if train[key]['play']=='yes':
            totalzheng+=1
        elif train[key]['play']=='no':
            totalfu+=1
    if totalfu==0 or totalzheng==0:
        return [total,totalzheng,totalfu,info]#如果全为yes或者全为no,则信息为0
    else:
        bili1=totalzheng/(totalzheng+totalfu)
        bili2=totalfu/(totalzheng+totalfu)
        info=bili1*math.log2(bili1)+bili2*math.log2(bili2)#计算公式为正数的比例×log2(正数的比例)
        return [total,totalzheng,totalfu,round(info,3)*-1]

def parttrain(train,targetattr,mainattr):#定义分离数组的方法,传入需要分离的训练集,需要分离出来的属性,和该属性所属的字段
    returndict={}#定义返回字典
    for key in train.keys():
        if train[key][mainattr]==targetattr:#如果该条的相应的属性值等于目标属性
            returndict[key]=train[key]
    return returndict

def attrset(train,attr):#求该属性在该训练集下的集合
    resset=[]
    for key in train.keys():
        resset.append(train[key][attr])#直接加入训练集中该属性下的属性值
    resset=set(resset)#去除重复值
    return resset 

class Tree():
    def __init__(self,root):#初始化函数,定义节点值和结点的孩子字典
        self.root=root
        self.child={}
    def addchild(self,attr,dict):#传入属性值和字典,构建孩子字典
        self.child[attr]=dict
    def show(self):#返回根节点
        a={}
        a[self.root]=self.child#将孩子字典赋值给根节点
        return a

def maxinfo(train,attrs):#定义求该训练集下attrs属性列表中信息增益最大的属性的方法
    maxattr=''
    maxnum=0
    for attr in attrs:#循环所有的属性
        attrtibutes=attrset(train,attr)#求该属性下的属性值
        infoall=info(train)#求该属性的信息量
        for shuxing in attrtibutes:#对于每个属性值
            attrtrain=parttrain(train,shuxing,attr)#先分理处该属性下的训练集
            shuxinginfo=info(attrtrain)#求该训练集信息量
            infoall[3]-=(shuxinginfo[0]/infoall[0])*shuxinginfo[3]#信息增益计算公式
        if infoall[3]>=maxnum:#找到拥有最大信息增益的属性
            maxnum=round(infoall[3],3)
            maxattr=attr
    return maxattr
    
def id3(examples,target,attributes):#id3方法
    root=Tree(target)#定义根节点
    examplesnum=info(examples)#先求训练集下的信息量
    if examplesnum[1]!=0 and examplesnum[2]==0:#如果训练集下yes不为零然后no为零,则全为yes,返回
        root.addchild(target,'yes')
    elif examplesnum[1]==0 and examplesnum[2]!=0:
        root.addchild(target,'no')
    elif len(attributes)==0:
        if examplesnum[1]>=examplesnum[2]:
            root.addchild(target,'yes')
        else:
            root.addchild(target,'no')
    else:
        attrs=attrset(examples,target)#定义属性集
        attributes.remove(target)
        for attr in attrs:
            nextexample=parttrain(examples,attr,target)
            target2=maxinfo(nextexample,attributes)
            xunhuanattrs=[]
            for i in range(0,len(attributes)):
                xunhuanattrs.append(attributes[i])
            root.addchild(attr,id3(nextexample,target2,xunhuanattrs))
    return root.show()

attrs=['outlook','temp','hum','wind']
target=maxinfo(train,attrs)
a=id3(train,target,attrs)
print(a)



  • 3
    点赞
  • 44
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值