机器学习实验3 基于信息增益率的决策树模型

  • 实验目的

1. 采用C4.5算法建立决策树模型

2. 将每个特征的信息增益率打印输出

3. 输出最后建立的决策树模型结构图(可以用字典结构表示,也可以用matplotlib等图形库绘制出树形图)

  • 实验内容:

1、使用np.loadtxt把txt文件中的数据变为ndarray,第一列是编号,在实验中不需要,使用切片去除。同时根据数据集的资料,列出每个属性名。

2、求信息熵。需要统计出总共有多少种类别,以及每类个数,使用字典存储查找快捷方便。按照公式进行计算即可。

3、根据选定的属性及该属性取值value,对数据集进行筛选,只保留该属性列值为value的行,并且去除该属性列,因为后面还要递归处理,这列还保留的话,可能会在一个分支上多次使用这个属性进行划分。去除列可以使用切片,之后再用np.hstack水平方向合并。

4、给出一个列表,求出其中哪个元素出现的次数最多。可以用于递归的数据只有一种属性,此时为叶节点,类别确定为出现的次数最多的类别。使用内置的collections库中的Counter进行统计并输出。

5、根据课本上的原理进行划分属性的选择,应当计算并选择增益率最大的属性。对每个属性求增益率:使用两层循环,外层枚举属性,内层统计该属性有多少种、占比、固有值。根据公式求出每个的增益率,选取增益率最大的属性作为划分属性。

6、根据书上的伪代码创建树。首先是两种特殊情况:如果只有一种类别,那么这个分支的所有样本都属于这个类别,变为叶节点;

如果样本在A上取值相同,就选择D中样本数最多的类。

与这两种情况都不相符的话,就要从A中选取增益率最大的属性进行划分,调用上面实现的函数choseBest即可得出。然后根据该属性取值不同,生成不同分支,再递归求解其子树。

实验中遇到的问题及解决:

在建树的处理过程中,选定了某个属性需要在标签中将其去除,防止与后面的数据对应不上。在递归参数传递时,使用splitData函数选取该属性值相同的样本放在一个节点上,并去除该属性的列。lables的传递需要注意使用切片生成一个新对象或者使用deepcopy出一个新的,否则传递指针导致在子进程中修改lables,父进程出错。

实验结果:

可以看到各个特征的信息增益率,由于决策树是多层的,可能在不同的分支对不同的属性信息增益率不同。但是前四行是第一次选定,可以看到tear的增益率最高,选定其作为划分依据,后面类似,与实验结果吻合。

根据资料搜集,可以使用matplotlib进行决策树的绘制,这样更加清晰直观。

  • 实验原理:

  • 心得体会:

通过这次实验我学习了三种决策树的思想以及C4.5算法的具体实现过程,在实践中对这些算法有了更加清晰的理解和掌握。同时还熟悉了信息熵、增益率等概念和决策树具体创建和工作流程以及其合理性。这也是对python数据处理、程序设计的能力考验,有的操作调用方法就可以较快实现,如ndarray的合并、列表求出现次数最多的元素。

    在实验中也遇到了因为python的函数传参机制是传递指针,造成递归的子进程修改父进程变量的问题;对于伪代码转换为可执行代码的手足无措;以及学习matplotlib实现决策树的绘制等等。为以后的学习打下坚实的基础。

import numpy as np
from math import log2
import treePlotter
from collections import Counter


def loadData(path):
    dataset = np.loadtxt(path)
    lable = ['age', 'prescription', 'astigmatic', 'tear']
    return dataset[:, 1:], lable


def getEntropy(data):
    num = len(data)
    lable_num = {}
    for i in data:  # 统计总共多少类,及每类个数
        if i[-1] not in lable_num.keys():
            lable_num[i[-1]] = 0
        lable_num[i[-1]] += 1
    e = 0
    for key in lable_num:
        p = float(lable_num[key]) / num
        e -= p * log2(p)
    return e


def splitData(data, classfication, value):
    subdata = []
    for row in data:
        if row[classfication] == value:
            subdata.append(np.hstack((row[:classfication], row[classfication + 1:])))
    # print(subdata)
    return subdata


def maximumClass(class_list):
    cnt = Counter(class_list)
    max_class = cnt.most_common()[0][0]
    return max_class


def choseBest(data, lables):
    feat_num = len(data[0]) - 1
    all_ent = getEntropy(data)
    max_gain_rate = -1
    best_feat = -1
    for i in range(feat_num):  # 枚举每种属性做划分的增益率
        feat_value = set([row[i] for row in data])
        new_ent = 0
        IV = 0
        for j in feat_value:
            subdata = splitData(data, i, j)  # 将数据中属性i的值为j的行挑出来
            p = float(len(subdata)) / len(data)  # 求占比
            IV -= p * log2(p)
            new_ent += p * getEntropy(subdata)
        gain = all_ent - new_ent
        if len(feat_value) == 1:  # 只有一种属性值
            continue
        gain_rate = gain / IV
        print(lables[i], " : ", gain_rate)
        if gain_rate > max_gain_rate:
            max_gain_rate = gain_rate
            best_feat = i
    return best_feat


def createTree(data, lables):
    class_list = [row[-1] for row in data]
    # print(class_list)
    if class_list.count(class_list[0]) == len(class_list):  # 只有一种结果
        # print(class_list.count(class_list[0]))
        return class_list[0]
    if len(data[0]) == 1:  # 属性只有一种,取种类最多的
        return maximumClass(class_list)
    best_feat = choseBest(data, lables)  # 属性的编号
    lable = lables[best_feat]  # 属性的名
    tree = {lable: {}}
    lables.remove(lables[best_feat])  # 已经使用过该属性划分,在其递归子程序中不会再用到,但是其他分支会用到
    feat_value = set([row[best_feat] for row in data])
    for i in feat_value:
        tree[lable][i] = createTree(splitData(data, best_feat, i), lables[:])
        # 这里传入lables列表时使用深拷贝,防止子进程修改父进程的lables
    return tree


if __name__ == '__main__':
    data, lables = loadData(r'D:\pythonProject\Machine_learning\test3\隐形眼镜数据集\lenses_data.txt')
    tree = createTree(data, lables)
    print(tree)
    treePlotter.createplot(tree)
import matplotlib.pyplot as plt
from pylab import *

mpl.rcParams['font.sans-serif'] = ['SimHei']

"定义文本框和箭头格式"
decisionnode = dict(boxstyle="sawtooth", fc="0.8")
leafnode = dict(boxstyle="round4", fc="0.8")
arrow_args = dict(arrowstyle="<-")

"绘制带箭头的注解"


def plotnode(nodetxt, centerpt, parentpt, nodetype):
    createplot.ax1.annotate(nodetxt,
                            xy=parentpt,
                            xycoords='axes fraction',
                            xytext=centerpt,
                            textcoords='axes fraction',
                            va='center',
                            ha='center',
                            bbox=nodetype,
                            arrowprops=arrow_args)


def createplot():
    fig = plt.figure(1, facecolor='white')  # 创建新图形
    fig.clf()  # 清空绘图区
    createplot.ax1 = plt.subplot(111, frameon=False)
    plotnode('决策节点', (0.5, 0.1), (0.1, 0.5), decisionnode)
    plotnode('叶节点', (0.8, 0.1), (0.3, 0.8), leafnode)
    plt.show()


def getnumleafs(mytree):
    numleafs = 0
    firststr = list(mytree.keys())[0]
    seconddict = mytree[firststr]
    for key in seconddict.keys():
        if type(seconddict[key]).__name__ == 'dict':
            numleafs += getnumleafs(seconddict[key])
        else:
            numleafs += 1
    return numleafs


def gettreedepth(mytree):
    maxdepth = 0
    firststr = list(mytree.keys())[0]
    seconddict = mytree[firststr]
    for key in seconddict.keys():
        if type(seconddict[key]).__name__ == 'dict':
            thisdepth = 1 + getnumleafs(seconddict[key])
        else:
            thisdepth = 1
        if thisdepth > maxdepth:
            maxdepth = thisdepth
    return maxdepth


def plotmidtext(cntrpt, parentpt, txtstring):
    xmid = (parentpt[0] - cntrpt[0]) / 2.0 + cntrpt[0]
    ymid = (parentpt[1] - cntrpt[1]) / 2.0 + cntrpt[1]
    createplot.ax1.text(xmid, ymid, txtstring, va="center", ha="center", rotation=30)


def plottree(mytree, parentpt, nodetxt):
    numleafs = getnumleafs(mytree)
    depth = gettreedepth(mytree)
    firststr = list(mytree.keys())[0]
    cntrpt = (plottree.xoff + (1.0 + float(numleafs)) / 2.0 / plottree.totalw, plottree.yoff)
    plotmidtext(cntrpt, parentpt, nodetxt)
    plotnode(firststr, cntrpt, parentpt, decisionnode)
    seconddict = mytree[firststr]
    plottree.yoff = plottree.yoff - 1.0 / plottree.totald
    for key in seconddict.keys():
        if type(seconddict[key]).__name__ == 'dict':
            plottree(seconddict[key], cntrpt, str(key))
        else:
            plottree.xoff = plottree.xoff + 1.0 / plottree.totalw
            plotnode(seconddict[key], (plottree.xoff, plottree.yoff), cntrpt, leafnode)
            plotmidtext((plottree.xoff, plottree.yoff), cntrpt, str(key))
    plottree.yoff = plottree.yoff + 1.0 / plottree.totald


def createplot(intree):
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    axprops = dict(xticks=[], yticks=[])
    createplot.ax1 = plt.subplot(111, frameon=False, **axprops)
    "使用下两行代码计算树节点的摆放位置,将树绘制在水平方向和垂直方向的中心位置"
    plottree.totalw = float(getnumleafs(intree))  # 树的宽度
    plottree.totald = float(gettreedepth(intree))  # 树的高度
    plottree.xoff = -0.5 / plottree.totalw;
    plottree.yoff = 1.0  # 追踪已经绘制的节点位置,以及放置下一个节点的恰当位置
    plottree(intree, (0.5, 1.0), '')
    plt.show()

  • 0
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值