《统计学习方法》手撕决策树ID3,C4.5

废话不多说,直接上代码
详细原理见
《统计学习方法》第五章决策树总结

import numpy as np


class DecisionTree(object):
    def __init__(self, tree_type):
        self.Tree = None
        self.tree_type = tree_type

    def build_ID3(self, D, A, e):
        """
        :param D: 训练数据集
        :param A: 特征集
        :param e: 阈值
        :return: 决策树T
        """
        n = D.shape[0]

        tree = {'is_leaf': False}
        X = D[:,:-1]
        Y = D[:,-1]

        # 判断是否符合终止条件
        # 所有样本属于同一类别、A为空集
        if len(np.unique(Y)) == 1 or len(A)==0:
            tree['is_leaf'] = True
            # 将D中实例数最大的类作为该节点标记
            tree['label'] = self.majority_vote(Y)
            return tree

        # 计算原数据集的信息熵
        origin_entropy = self.entropy(Y)
        # 保存使用每一个属性分割后带来的信息增益或信息增益比
        gains = []

        for i in range(len(A)):
            uniques, counts = np.unique(X[:, i], return_counts=True)
            # 针对每一个属性计算分割后的信息熵
            entropy = 0
            for j in range(uniques.shape[0]):
                value, count = uniques[j], counts[j]
                entropy += (count / n) * self.entropy(Y[X[:, i] == value])
            # 计算信息增益
            gain = origin_entropy - entropy
            if self.tree_type == 'ID3':
                gains.append(gain)
            if self.tree_type == 'C4.5':
                h = self.entropy(X[:, i])
                gains.append(gain / h)

        # 如果该特征最优的信息增益小于阈值,则返回决策树
        if max(gains)<e:
            tree['is_leaf'] = True
            # 将D中实例数最大的类作为该节点标记
            tree['label'] = self.majority_vote(Y)
            return tree

        col = np.argmax(gains)
        Ag = A[col]  # 挑选信息增益最大的特征切分
        tree['Ag_name'] = Ag
        tree['Ag_index'] = col
        tree['children'] = {}
        uniques = np.unique(X[:, col])
        for value in uniques:
            id = X[:, col] == value
            # 水平拼接训练集构成子集,即去除最优特征的这一列
            Di = np.hstack((D[id, :col], D[id, col + 1:]))
            # 以A-Ag为新的特征集
            A_Ag = np.hstack((A[:col], A[col + 1:]))
            tree['children'][value] = self.build_ID3(Di, A_Ag, e)

        return tree


    def majority_vote(self, targets):
        if len(targets) == 0:
            return
        uniques, counts = np.unique(targets, return_counts=True)
        return uniques[np.argmax(counts)]

    def entropy(self, D):
        _, C = np.unique(D, return_counts=True) # 返回无重复元素列表以及每个元素在旧列表里各自出现了几次
        p = C / D.shape[0]
        H = -(p * np.log2(p)).sum()
        return H

    def predict(self, tree, data):
        if tree['is_leaf']:
            return tree['label']

        return self.predict(tree['children'][data[tree['Ag_index']]],
                                np.hstack((data[:tree['Ag_index']], data[tree['Ag_index'] + 1:])))

data=np.array([['青年','否','否','一般','否'],
      ['青年','否','否','好','否'],
      ['青年','是','否','好','是'],
      ['青年','是','是','一般','是'],
      ['青年','否','否','一般','否'],
      ['中年','否','否','一般','否'],
      ['中年','否','否','好','否'],
      ['中年','是','是','好','是'],
      ['中年','否','是','非常好','是'],
      ['中年','否','是','非常好','是'],
      ['老年','否','是','非常好','是'],
      ['老年','否','是','好','是'],
      ['老年','是','否','好','是'],
      ['老年','是','否','非常好','是'],
      ['老年','否','否','一般','否']])


A=['年龄','有工作','有自己房子','信贷情况']
tree = DecisionTree('C4.5')
tree.Tree = tree.build_ID3(data, A, 0)
print(tree.Tree)
print(tree.predict(tree.Tree, ['青年','否','是','非常好']))
  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
C语言是一种广泛使用的编程语言,它具有高效、灵活、可移植性强等特点,被广泛应用于操作系统、嵌入式系统、数据库、编译器等领域的开发。C语言的基本语法包括变量、数据类型、运算符、控制结构(如if语句、循环语句等)、函数、指针等。在编写C程序时,需要注意变量的声明和定义、指针的使用、内存的分配与释放等问题。C语言中常用的数据结构包括: 1. 数组:一种存储同类型数据的结构,可以进行索引访问和修改。 2. 链表:一种存储不同类型数据的结构,每个节点包含数据和指向下一个节点的指针。 3. 栈:一种后进先出(LIFO)的数据结构,可以通过压入(push)和弹出(pop)操作进行数据的存储和取出。 4. 队列:一种先进先出(FIFO)的数据结构,可以通过入队(enqueue)和出队(dequeue)操作进行数据的存储和取出。 5. 树:一种存储具有父子关系的数据结构,可以通过中序遍历、前序遍历和后序遍历等方式进行数据的访问和修改。 6. 图:一种存储具有节点和边关系的数据结构,可以通过广度优先搜索、深度优先搜索等方式进行数据的访问和修改。 这些数据结构在C语言中都有相应的实现方式,可以应用于各种不同的场景。C语言中的各种数据结构都有其优缺点,下面列举一些常见的数据结构的优缺点: 数组: 优点:访问和修改元素的速度非常快,适用于需要频繁读取和修改数据的场合。 缺点:数组的长度是固定的,不适合存储大小不固定的动态数据,另外数组在内存中是连续分配的,当数组较大时可能会导致内存碎片化。 链表: 优点:可以方便地插入和删除元素,适用于需要频繁插入和删除数据的场合。 缺点:访问和修改元素的速度相对较慢,因为需要遍历链表找到指定的节点。 栈: 优点:后进先出(LIFO)的特性使得栈在处理递归和括号匹配等问题时非常方便。 缺点:栈的空间有限,当数据量较大时可能会导致栈溢出。 队列: 优点:先进先出(FIFO)的特性使得

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Hilbob

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值