机器学习——决策树及绘图(tree&ID3)

决策树(ID3)

参考文章(关于用matplotlib画决策树):

需要用到的模块:math,matplotlib.pyplot,operator
 

1.思路

  • 选择划分数据集属性的方法
  • 找到划分数据集的最佳属性
  • 根据最佳属性划分数据集
  • 若划分后属于同一类别,结束分类;否则继续选择下一最佳属性
  • 直到分完所有属性,按多数表决法分类
  • 根据树绘图
  • 保存分类器,以供下一次使用
     

2.函数

 
生成树

  • 函数1:信息增益计算(熵)
  • 函数2:划分数据集
  • 函数3:选择最优属性
  • 函数4:确定叶子节点类别(多数表决法)
  • 函数5:生成树
  • 函数6:按照生成树给测试集分类
  • 函数7:存储构造好的树
  • 函数8:提取构造好的树
     

画图

  • 确定决策节点和叶节点的样式
  • 函数1:节点绘制
  • 函数2:确定生成树的叶节点数量和层数
  • 函数3:绘制节点之间的信息
  • 函数4:绘制树
  • 函数5:生成图
     
     

3. 生成树

 

'''
ID3算法
当前仅适用于离散型数据
'''


from math import log
import operator

## 信息增益:计算信息熵
def info_shan(dataset):
    label_counts = {}
    num_data = len(dataset)
    # 计算每个分类出现的频率
    for vet in dataset:
        current_label = vet[-1]
        if current_label not in label_counts.keys():
            label_counts[current_label] = 0
        label_counts[current_label] += 1
    shan = 0.0
    for key in label_counts:
        probility = float(label_counts[key])/num_data
        shan -= probility * log(probility,2)
    return shan

## 划分数据集(第二个参数是特征在数据集的位置)
def splitdata(dataset,feature,value):
    sub_dataset = []
    # 将数据集按指定属性的值划分
    for vector in dataset:
        if vector[feature] == value :
            # 提取符合要求的属性且不包含该指定属性
            fea_vector = vector[:feature]
            fea_vector.extend(vector[feature+1:])
            sub_dataset.append(fea_vector)
    return sub_dataset

## 选择划分数据集的最佳属性(信息增益)
def best_feature(dataset):
    # 默认最后一列是标签
    num_feature = len(dataset[0])-1
    shan = info_shan(dataset)
    best_info_gain = 0.0
    best_feature = -1
    # 遍历数据集中所有特征
    for i in range(num_feature):
        fea_list = [x[i] for x in dataset]
        # 建立集合,得到不重复的值的集合
        fea_vals = set(fea_list)
        fea_shan = 0.0
        # 根据值的不同划分子集
        for value in fea_vals:
            sub_dataset = splitdata(dataset,i,value)
            probility = len(sub_dataset)/float(len(dataset))
            fea_shan += probility * info_shan(sub_dataset)
        info_gain = shan - fea_shan
        if info_gain > best_info_gain:
            best_info_gain = info_gain
            best_feature = i
    return best_feature

## 确定叶子节点的类别(多数表决法)            
def majority_vote(classlist):
    class_count={}
    # 遍历每个类别,得到频数
    for vote in classlist:
        if vote not in class_count.keys():
            class_count[vote] = 0
        class_count[vote]+=1
    # 找到最高频数的类别
    class_label = sorted(class_count.items(),key=operator.itemgetter(1),\
                         reverse=True)
    return class_label[0][0]

# 生成树    
def create_tree(dataset,labels):
    classlist = [x[-1] for x in dataset]
    # 若类别相同,则停止划分
    if classlist.count(classlist[0]) == len(classlist):
        return classlist[0]
    # 若所有特征均被遍历,返回出现次数最多的类别
    if len(dataset[0]) == 1:
        return majority_vote(classlist)
    # 选择最优属性
    best_fea = best_feature(dataset)
    best_fea_label = labels[best_fea]
    mytree = {best_fea_label:{}}
    del(labels[best_fea])
    fea_val = [x[best_fea] for x in dataset]
    vals = set(fea_val)
    for val in vals:
        sublabels = labels[:]
        mytree[best_fea_label][val] = create_tree(splitdata(dataset,best_fea,val),\
              sublabels)
    return mytree

## 分类   
def classify(input_tree,feature_label,testvec):
    keys = list(input_tree.keys())
    firststr = keys[0]
    second_dict = input_tree[firststr]
    # 确定进行判断的属性
    fea_index = feature_label.index(firststr)
    for key in second_dict.keys():
        # 确定测试集在该属性下的值,遍历所有节点进行匹配
        if testvec[fea_index] == key:
            # 若匹配到的节点有下一层树,则递归,根据下一个属性分类
            if type(second_dict[key]).__name__ == 'dict':
                classlabel = classify(second_dict[key],feature_label,testvec)
            else:
                # 若匹配到的节点已经是叶节点,则确定分类
                classlabel = second_dict[key]
    return classlabel

## 存储构造好的树(便于下一次直接用树分类)
def store_tree(tree,filename):
    import pickle
    # python3不接受二进制文件,需要用二进制写入模式(用with语句更简洁)
    with open(filename,'wb') as filewrite:
        pickle.dump(tree,filewrite)

## 获取存储的树
def grab_tree(filename):
    import pickle
    with open(filename,'rb') as fileread:
        return pickle.load(fileread)

 
 

4. 画图

 

import matplotlib.pyplot as plt

# 决策节点的样式:boxstyle是文本框形状,sawtooth是锯齿型,fc是边框粗细
decision_node = dict(boxstyle = 'sawtooth',fc='0.8')

# 叶节点的样式
leaf_node = dict(boxstyle = 'round4',fc='0.8')

# 箭头属性
arrow_args = dict(arrowstyle = '<-')

## 绘制带箭头的注解
'''
nodetxt-文本
centerpt-文本中心点
dotpt-箭头指向文本的点
nodetype判断是叶节点还是决策节点的样式
'''
def plot_node(nodetxt,centerpt,dotpt,nodetype):
    # annotate()函数用于标注文字
    '''
    nodetxt-标注内容
    xy-箭头指向的点的坐标
    xycoords-指向的点的坐标属性(以子绘图区左下角为参考,单位是百分比)
    xytext-标注内容的坐标
    textcoords-文本的坐标属性
    va/ha-点的位置(va:top, bottom, center, baseline;ha:right,center,left)
    bbox-内容增加外框
    arrowprops-箭头参数(字典形式)
    '''
    create_plot.ax1.annotate(nodetxt,xy=dotpt,xycoords='axes fraction',\
                            xytext=centerpt,textcoords='axes fraction',\
                            va='center',ha='center',bbox=nodetype,
                            arrowprops=arrow_args)


# 判断叶节点的个数和决策树层数
def get_leafnum_depth(mytree):
    leafnum = 0
    depth_all = 0
    keys = list(mytree.keys())
    firststr = keys[0]
    second_dict = mytree[firststr]
    for key in second_dict.keys():
        if type(second_dict[key]).__name__ == 'dict':
            leafnum0,depth0 = get_leafnum_depth(second_dict[key])
            leafnum += leafnum0
            depth = 1+ depth0
        else:
            leafnum += 1
            depth = 1
        if  (depth > depth_all): depth_all = depth
    return leafnum,depth_all

# 节点之间的线的中间补充文本信息
def plot_midtext(centerpt,dotpt,stringtxt):
    xmid = (dotpt[0] - centerpt[0])/2.0 + centerpt[0]
    ymid = (dotpt[1] - centerpt[1])/2.0 + centerpt[1]
    # 在ax1图上标注信息(x-位置,y-位置,文本)
    create_plot.ax1.text(xmid,ymid,stringtxt)

## 树的绘制
'''将叶子节点数作为份数平均切分整个x轴,将树的层数作为份数平均切分y轴长度'''    
def plot_tree(mytree,dotpt,nodetxt):
    leafnum,depth = get_leafnum_depth(mytree)
    keys = list(mytree.keys())
    firststr = keys[0]
    # 确定节点的位置(xoff(上一个节点的位置)+偏移量,yoff不变)
    '''
    确定节点位置时每次需确定当前层有几个叶子节点,
    这层所有叶子节点所占的总距离即为float(leafnum)/plotTree.totalw
    而当前节点的位置即为其所有叶子节点所占距离的中间即为float(leafnum)/2.0/plotTree.totalw
    由于开始plotTree.xoff赋值左移了半个距离,因此还需加上1/2/plotTree.totalw
    '''
    centerpt = (plot_tree.xoff + (1.0 + float(leafnum))/2.0/plot_tree.totalw,\
                                  plot_tree.yoff)
    plot_midtext(centerpt,dotpt,nodetxt)
    plot_node(firststr,centerpt,dotpt,decision_node)
    second_dict = mytree[firststr]
    # 确定同层下一个节点y轴位置(偏移量为1/层数)
    plot_tree.yoff = plot_tree.yoff - 1.0/plot_tree.totald
    for key in second_dict.keys():
        if type(second_dict[key]).__name__ == 'dict':
            # 若该节点是父节点,则递归自身
            plot_tree(second_dict[key],centerpt,str(key))
        else:
            # 若是叶节点,直接绘制
            plot_tree.xoff = plot_tree.xoff + 1.0/plot_tree.totalw
            plot_node(second_dict[key],(plot_tree.xoff,plot_tree.yoff),\
                      centerpt,leaf_node)
            plot_midtext((plot_tree.xoff,plot_tree.yoff),centerpt,str(key))
    plot_tree.yoff = plot_tree.yoff + 1.0/plot_tree.totald

    
# 生成图
def create_plot(intree):
    # 新建一个绘图窗口,facecolor-区域背景色
    figure1 = plt.figure(num=1,facecolor = 'white')
    # 清空绘图区
    figure1.clf()
    # 设置坐标轴刻度标签
    axprops = dict(xticks=[],yticks=[])
    # 创建该函数的属性(ax1)
    '''创建一个新的子图,图绘制在第一块,frameno-是否绘制图像边框'''
    create_plot.ax1 = plt.subplot(111,frameon = False,**axprops)
    # totalw-叶节点数;totald-树的层数
    leafnum,depth = get_leafnum_depth(intree)
    plot_tree.totalw = float(leafnum)
    plot_tree.totald = float(depth)
    # xoff,yoff表示第一个节点位置
    '''1/叶节点个数=两个节点相隔的距离,
    *(-1/2)表示初始位置向左偏移0.5个距离(为了图形好看)'''
    plot_tree.xoff = (1/plot_tree.totalw)*(-1/2)
    plot_tree.yoff = 1
    # 构建决策树,第一个节点在(0.5,1.0)
    plot_tree(intree,(0.5,1.0),'')
    plt.show()    
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值