ML作业3——ID3决策树算法

数据集:


ID3算法:

ID3算法是以信息熵和信息增益为衡量标准,从而实现对数据的归纳分类的一种算法。

首先,ID3算法需要解决的问题是如何选择特征作为划分数据集的标准。在ID3算法中,选择信息增益最大的属性作为当前的特征对数据集分类。

其次,ID3算法需要解决的问题是如何判断划分的结束。分为两种情况,第一种为划分出来的类属于同一个类,第二种为已经没有属性可供再分了。此时就结束了。

通过递归的方式,得到ID3决策树模型,它是局部最优的且仅输出单个目标,并近似的喜好“最短的”决策树。

信息熵与信息增益:

信息熵(information entropy):

信息熵是度量样本集合纯度最常用的一种指标。假定当前的样本集合D中第k类样本所占比例为,则D的信息熵定义为:

由信息熵定义可知熵越大,训练集D中的样本的类别越不纯。

信息增益(information gain):

假定离散属性aV个可能的取值,若使用a来对样本集D进行划分,于是可计算出用属性a对样本集D进行划分所获得的“信息增益”,公式如下:

由信息增益的公式可知,信息增益越大,意味着使用属性a来进行划分所获得的“纯度提升”越大。

算法流程:

创建决策树的根节点root

若所有的样本均为正例,返回单个根节点root并标记为“+

若所有的样本均为反例,返回单个根节点root并标记为“-

若属性值为空,返回单个根节点root并标记为大部分样本的标签

其他情况下:

    选取当前例子下最大信息增益的属性A作为分类属性,令A指向根节点

              For each A中的属性值Vi

        增加一个关于Vi的分支

                            For each符合A=Vi的样本ExampleVi

            若ExampleVi为空

                                                    投票选出满足其父节点训练集的大部分标签

                                         否则

                                                    递归产生决策树

代码:

# -*- coding: utf-8 -*-
"""
Created on Thu Apr  5 13:36:19 2018

@author: 安颖
"""
import numpy as np
from math import log
import draw_tree

#定义属性值
Outlook = ["Sunny","Overcast","Rain"]
Temperature = ["Hot","Mild","Cool"]
Humidity = ["High","Normal"]
Wind = ["Strong","Weak"]
Attri=[]
Attri.append(Outlook)
Attri.append(Temperature)
Attri.append(Humidity)
Attri.append(Wind)

#数据集
my_data = []
with open('data.txt', 'r') as data_txt:
    data = data_txt.readlines()
    for line in data:
        temp = line.split(',')
        my_data.append([temp[0],temp[1],temp[2],temp[3],int(temp[4])])
    my_data = np.array(my_data)
data_txt.close()

#记录不同属性值的label比值
def count_label(data):
    #初始化属性组
    Anum = [[[0,0],[0,0],[0,0]],[[0,0],[0,0],[0,0]],[[0,0],[0,0]],[[0,0],[0,0]]]
    #进行计数
    for i in range(len(data)):
        #print(str(i)+":"+str(data[i]))
        for j in range(len(Attri)):
            for k in range(len(Attri[j])):
                if data[i][j]== Attri[j][k]:
                    if int(data[i][4]) == 1:
                        Anum[j][k][0] += 1
                    else:
                        Anum[j][k][1] += 1
                    break
    return Anum


#计算信息增益
def cal_gain(data,label,targe_attr):
    #取类标签集合
    classList=[example[4] for example in data]
    #计算信息熵
    ent_s = 0.0
    if len(data)!=0 and classList.count('0') != 0:
        prob = (float)(classList.count('0')/len(data))
        ent_s -= prob*log(prob,2)
    if len(data)!=0 and classList.count('1') != 0:
        prob = (float)(classList.count('1')/len(data))
        ent_s -= prob*log(prob,2)
    #计算增益
    gain = 0.0
    for i in range(len(Attri[targe_attr])):
        ent = 0.0
        sum = label[targe_attr][i][0] + label[targe_attr][i][1]
        if sum == 0:
            continue
        for j in range(2):
            prob = (float)(label[targe_attr][i][j]/sum)
            #p=0时认为熵为0
            if prob == 0.0:
                continue
            else:
                ent -= prob*log(prob,2)
        gain -= ent*(sum/len(data))
    gain += ent_s
    return gain

#选择信息增益最大的属性
def choose_attr(data,label,attr):
    max = [0]*len(label)
    #print("data:"+str(len(data)))
    for i in range(len(label)):
        if attr[i] != 'selectedAtt':
            max[i] = cal_gain(data,label,i)
        else:
            max[i] = -1
    temp = np.argsort(max)
    #全部选完
    if attr[temp[3]] == 'selectedAtt':
        return -1
    else:
        return temp[3]

#判断属性值是否为空
def pdnull(attr):
    flag = 0
    for i in attr:
        if i != 'selectedAtt' :
            flag = 1
            break
    return flag

#投票选出大多数标签
def vote(classList):
    if classList.count('1') > classList.count('0'):
        flag = '1'
    else:
        flag = '0'
    return flag

#ID3构建树算法
def id3_tree(examples,targe_attr,attr):
    #取类标签集合
    classList=[example[4] for example in examples]
    #若全为正例,返回正标签
    if classList.count('1')==len(classList):
        myTree = '1'
        return myTree
    #若全为反例,返回负标签
    if classList.count('0')==len(classList):
        myTree = '0'
        return myTree
    #若属性值为空,返回单个根节点root并标记为大部分样本的标签
    if len(attr)==0:
        myTree = vote(classList)
        return myTree
    #否则进行递归建树    
    else:
        AttributeLabel = attr[targe_attr]
        myTree={AttributeLabel:{}}
        #已用过的属性进行标记
        #attr[targe_attr] = 'selectedAtt'
        #该属性值列表
        featValues=[example[targe_attr] for example in examples]
        for i in range(len(Attri[targe_attr])):
            #若没有属性则投票获得大多数票数的标签值
            attr_value = Attri[targe_attr][i]
            #若没有当前属性值的数据则进行投票
            if featValues.count(attr_value)==0 :
                myTree[AttributeLabel][attr_value] = vote(classList)
            else:
                sub_examples=[]
                for j in range(len(examples)):
                    if featValues[j]==Attri[targe_attr][i]:
                        sub_examples.append(examples[j])
                #若子属性还有数据值,递归
                #if len(sub_examples) > 0:
                sub_attr = attr[:]
                #已用过的属性进行标记
                sub_attr[targe_attr] = 'selectedAtt'
                sub_label = count_label(sub_examples)
                sub_targe_attr = choose_attr(sub_examples,sub_label,sub_attr)
                if sub_targe_attr != -1:
                    myTree[AttributeLabel][attr_value] = id3_tree(sub_examples,sub_targe_attr,sub_attr)
        return myTree
        
        
if  __name__    ==   '__main__':
    #Att数组记录不同属性
    Att = ["Outlook","Temperature","Humidity","Wind"]
    #选择第一个合适的属性值
    label = count_label(my_data)
    targe_attr = choose_attr(my_data,label,Att)
    tree = id3_tree(my_data,targe_attr,Att)
    print(tree)
    draw_tree.createPlot(tree)
    
    #test2
    data = ['Rain','Cool','Normal','Strong','0'] 
    print("加入新样本"+str(data)+"后:")
    new_data = my_data[:,:]
    #加入新样本
    new_data = np.row_stack((new_data,data))
    #选择第一个合适的属性值
    new_label = count_label(new_data)
    targe_attr = choose_attr(new_data,new_label,Att)
    new_tree = id3_tree(my_data,targe_attr,Att)
    print(new_tree)
    draw_tree.createPlot(new_tree)
    
    

画决策树:

# -*- coding: utf-8 -*-
"""
Created on Thu Apr  5 11:59:17 2018

@author: 安颖
"""

import matplotlib.pyplot as plt
decisionNode=dict(boxstyle="sawtooth",fc="0.8")
leafNode=dict(boxstyle="round4",fc="0.8")
arrow_args=dict(arrowstyle="<-")
 
 
#计算树的叶子节点数量
def getNumLeafs(myTree):
  numLeafs=0
  firstSides = list(myTree.keys())
  firstStr=firstSides[0]
  secondDict=myTree[firstStr]
  for key in secondDict.keys():
    if type(secondDict[key]).__name__=='dict':
      numLeafs+=getNumLeafs(secondDict[key])
    else: numLeafs+=1
  return numLeafs
 
#计算树的最大深度
def getTreeDepth(myTree):
  maxDepth=0
  firstSides = list(myTree.keys())
  firstStr=firstSides[0]
  secondDict=myTree[firstStr]
  for key in secondDict.keys():
    if type(secondDict[key]).__name__=='dict':
      thisDepth=1+getTreeDepth(secondDict[key])
    else: thisDepth=1
    if thisDepth>maxDepth:
      maxDepth=thisDepth
  return maxDepth
 
#画节点
def plotNode(nodeTxt,centerPt,parentPt,nodeType):
  createPlot.ax1.annotate(nodeTxt,xy=parentPt,xycoords='axes fraction',\
  xytext=centerPt,textcoords='axes fraction',va="center", ha="center",\
  bbox=nodeType,arrowprops=arrow_args)
 
#画箭头上的文字
def plotMidText(cntrPt,parentPt,txtString):
  lens=len(txtString)
  xMid=(parentPt[0]+cntrPt[0])/2.0-lens*0.002
  yMid=(parentPt[1]+cntrPt[1])/2.0
  createPlot.ax1.text(xMid,yMid,txtString)
 
def plotTree(myTree,parentPt,nodeTxt):
  numLeafs=getNumLeafs(myTree)
  depth=getTreeDepth(myTree)
  firstSides = list(myTree.keys())
  firstStr=firstSides[0]
  cntrPt=(plotTree.x0ff+(1.0+float(numLeafs))/2.0/plotTree.totalW,plotTree.y0ff)
  plotMidText(cntrPt,parentPt,nodeTxt)
  plotNode(firstStr,cntrPt,parentPt,decisionNode)
  secondDict=myTree[firstStr]
  plotTree.y0ff=plotTree.y0ff-1.0/plotTree.totalD
  for key in secondDict.keys():
    if type(secondDict[key]).__name__=='dict':
      plotTree(secondDict[key],cntrPt,str(key))
    else:
      plotTree.x0ff=plotTree.x0ff+1.0/plotTree.totalW
      plotNode(secondDict[key],(plotTree.x0ff,plotTree.y0ff),cntrPt,leafNode)
      plotMidText((plotTree.x0ff,plotTree.y0ff),cntrPt,str(key))
  plotTree.y0ff=plotTree.y0ff+1.0/plotTree.totalD
 
def createPlot(inTree):
  fig=plt.figure(1,facecolor='white')
  fig.clf()
  axprops=dict(xticks=[],yticks=[])
  createPlot.ax1=plt.subplot(111,frameon=False,**axprops)
  plotTree.totalW=float(getNumLeafs(inTree))
  plotTree.totalD=float(getTreeDepth(inTree))
  plotTree.x0ff=-0.5/plotTree.totalW
  plotTree.y0ff=1.0
  plotTree(inTree,(0.5,1.0),'')
  plt.show()

  • 1
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值