Python机器学习算法实践——决策树(ID3)

一.决策树模型

决策树是一种通过对特征属性属性分类对样本进行分类的树形结构,包括边和三类节点

  • 根节点:决策树的起源,进行分类的第一个特征属性,只有出边没有入边;
  • 内部节点:正在进行分类的特征属性,有一条入边,至少有一条出边;
  • 叶节点:分类结束的特征属性,有入边,没有出边;

决策树模型
上面就是一个简单的二叉决策树

二.决策树学习

决策树学习本质就是对样本总结一个分类规则,既能对已知样本进行合理的拟合分类,也能对未知样本进行正确的分类预测,这其中就有两个关键:

  • 选择哪个特征属性进行分类?
  • 到什么时候停止分类?

三.决策树算法

  • 特征属性选择

在ID3中,特征属性的选择是由目标函数决定的,目标函数代表的是特征属性的混乱程度(也就是特征属性越混乱越不好分类,该特征属性的分类顺序越靠后),这个目标函数就是信息增益,信息增益是由熵计算出来的:
熵:Entropy(t)=−∑kp(ck|t)logp(ck|t)
信息增益:
Δ=H(c)−∑i=1nN(ai)NH(c|ai)
=H(c)−∑inp(ai)H(c|ai)
=H(c)−H(c|A)
其中H(c)表示父节点的熵值,H(c|A)表示该父节点下特征属性的熵的加权和

  • 决策树生成

1.如果节点满足停止分裂的条件,则将其设为叶节点;
2.如果节点不满足停止分裂的条件,则选择信息增益最大的属性分裂;
3.重复1-2动作直到分类完成
C4.5与ID3的区别就是属性的选择条件为信息增益比

四.算法实现

# -*- coding=utf-8 -*-
"""
author:xuwf
created:2017-2-10
purpose:实现决策树算法C4.5
"""

from numpy import *

'''装载文本'''
def load():
    data=loadtxt('d:/myProject/data/destree.txt',delimiter=',',dtype=str)
    return data

'''熵计算'''
def calcShan(dataset):
    rows=dataset.shape[0]
    shannon=0.0 #熵值
    dataCount=dataInfo(dataset)
    for i in dataCount:
        temp=dataCount[i]/float(rows)   #计算概率
        shannon-=temp*(log2(dataCount[i])-log2(rows))   #计算熵
    return shannon

'''统计数据'''
def dataInfo(dataset):
    info={}
    for i in dataset:
        info[i]=info.get(i,0)+1 #计算每个数据出现的次数(dict.get方法获取)
    return info

'''最大增益比计算'''
def calcGain(dataset):
    rows,cols=dataset.shape
    shan=calcShan(dataset[:,cols-1])    #计算总熵
    shanPro=zeros(cols-1)   #每个属性的熵增益比
    for i in range(cols-1):
         dataCount=dataInfo(dataset[:,i])
         for j in dataCount.keys():
            temp=dataCount[j]/float(rows)
            shanPro[i]+=temp*calcShan(dataset[dataset[:,i]==j,cols-1])
         shanPro[i]=(shan-shanPro[i])/shan
    return argsort(-shanPro)[0]

'''计算分支熵'''
def calcShanBranch(dataset,i):
    rows,cols=dataset.shape
    shanBranch={}
    dataCount=dataInfo(dataset[:,i])
    for j in dataCount.keys():
        shanBranch[j]=calcShan(dataset[dataset[:,i]==j,cols-1])
    return shanBranch

'''建树干'''
def buildTree(dataset,label,tree):
    sortRatio=calcGain(dataset) #找出熵增益最大的那个下标,从这个下标开始分类
    shanBranch=calcShanBranch(dataset,sortRatio)    #计算分支的熵,判断是否为同一类
    tree={label[sortRatio]:{}}
    for i in shanBranch.keys():
        dataNew=dataset[dataset[:,sortRatio]==i,:]
        dataNew=dataNew[:,dataset[0,:]!=dataset[0,sortRatio]]
        labelNew=label[label!=label[sortRatio]]
        if shanBranch[i]==0:
            #如果分支的熵为0,说明该分支下的类别为同一类,该分支分类完成
            tree[label[sortRatio]][i]=dataset[dataset[:,sortRatio]==i,-1][0]
        else:
            if label.shape[0]==2:
                #如果到了最后的一组就返回数据
                return tree
            else:
                #没有到最后一组就继续循环
                tree[label[sortRatio]][i]=buildTree(dataNew,labelNew,tree)
    return tree

'''主函数'''
def main():
    data=load()
    label=data[0,:]
    data=data[1:,:]
    tree={}
    tree=buildTree(data,label,tree)
    return tree

if __name__=='__main__':
    main()

五.画图生成决策树

# -*- coding=utf-8 -*-

"""
author:wfxu
create:2017-02-22
purpose:实现plt画图
"""

import matplotlib.pyplot as plt
import destree

def cyc(tree,x,y,coord,ax): 
  for i in tree:
    x1=x    #起点坐标
    y1=y
    ax.text(x1,y1,i,fontsize=20,va='center',ha='center',bbox=dict(facecolor='red', alpha=0.5))  #起点
    x2=x-1  #终点坐标
    y2=y-1
    for j in tree[i]:
        if (x2,y2) in coord:
            x2+=1
        ax.text((x1+x2)/2.0,(y1+y2)/2.0,j,fontsize=10,va='center',ha='center')  #中间点
        ax.arrow(x1,y1,x2-x1,y2-y1,width=0.05,length_includes_head=True,overhang=0.3,head_length=0.1)
        #print (j,(x1+x2)/2.0,(y1+y2)/2.0),(tree[i][j],x2,y2)
        x2+=1
        if tree[i][j] in ('yes','no'):
            #判断有没有属性点重合的,如果有就向右移动一位
            coord.append((x2-1,y2))
            ax.text(x2-1,y2,tree[i][j],fontsize=20,va='center',ha='center',bbox=dict(facecolor='blue', alpha=0.5))  #终点
        else:
            cyc(tree[i][j],x2-1,y2,coord,ax)

def main():
    fig=plt.figure()
    ax=fig.add_subplot(111)
    ax.plot((5,1.5),(3.5,0.5),'b+')
    tree=destree.main()
    x=3
    y=3
    coord=[]
    cyc(tree,x,y,coord,ax)
    fig.show()

if __name__=='__main__':
    main()

决策树效果
用plt画的图:
这是程序运行后的决策树

六.总结

1.这次用的数据比较单一,所以最后的结果都是“纯净的”,如果样本空间很大,那么最后的分裂停止条件就应该是一个阈值(比如80%以上为yes就视为分裂完成),这样的结果就是预测结果会有偏差
2.如果预测的结果又偏差,就需要进行剪枝,这里并没有剪枝,希望下次可以拿一个大的样本进行测试
3.这里画图是在坐标轴上画的,看起来很压抑,希望下次有方法可以在白画布上画出决策树
如果大家有什么更好的方法或意见请分享至下面的留言!

  • 2
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值