python实现决策树算法

import numpy as np
class treenode():
    def __init__(self):
        self.cld = []

    def add_cld(self, item):
        self.cld.append(item)

    def set_a_name(self, name):
        self.name = name

    def set_par(self, par):
        self.par = par


def comput_entD(d_y, d_n):
    d_v_num = d_y + d_n
    p_y = d_y / d_v_num
    p_n = d_n / d_v_num
    ent_D = -(p_y * np.log2(p_y) + p_n * np.log2(p_n))
    if (p_y == 0) | (p_y == 1):
        ent_D = 0
    return ent_D


def get_optional(D, A, train_data):
    res = -1
    temp = -1
    where = np.where(train_data[D, -1] == '是', 1, 0)
    d_y = where @ where
    where = np.where(train_data[D, -1] == '否', 1, 0)
    d_n = where @ where
    ent_D = comput_entD(d_y, d_n)  # 计算D的信息熵,用于计算属性的信息增益
    for b in A:
        attr_arr = np.unique(train_data[D, b])  # 属性b的所有取值,比如b=色泽,attr_arr=[青绿,浅白,乌黑]
        ent_arr = np.array([0] * len(attr_arr), dtype=float)  # 暂存信息熵
        gain_arr = np.array([0] * len(attr_arr), dtype=float)  # 暂存|D_v|/|D|
        for k in range(len(attr_arr)):
            where = np.where((train_data[D, b] == attr_arr[k]) & (train_data[D, -1] == '是'), 1, 0)
            d_y = where @ where  # 属性b的取值为b*的样本中为正例的个数
            where = np.where((train_data[D, b] == attr_arr[k]) & (train_data[D, -1] == '否'), 1, 0)
            d_n = where @ where  # 属性b的取值为b*的样本中为负例的个数
            ent_arr[k] = comput_entD(d_y, d_n)  # 计算信息熵
            gain_arr[k] = (d_y + d_n) / len(D)
        gain_a = ent_D - ent_arr @ gain_arr  # 计算信息增益
        if gain_a > temp:  # 取最大的信息增益
            res = b
            temp = gain_a
    return res  # 返回取得最大信息增益的属性(以下标表示)


def get_type(D, train_data):
    where = np.where(train_data[D, -1] == '是', 1, 0)
    d_y = where @ where
    where = np.where(train_data[D, -1] == '否', 1, 0)
    d_n = where @ where
    if d_y >= d_n:
        return '是'
    return '否'


# param D 为当前样本集合(下标表示)
# param A 为当前属性集合(下标表示)
# param train_data 为所有训练样本
# param par 为当前结点的父属性的值
def create_tree(D, A, train_data, par):
    node = treenode()
    node.set_par(par)
    type_name = get_type(D, train_data)
    temp_arr = []
    for c in A:
        uni = np.unique(train_data[D, c])
        for u in uni:
            temp_arr.append(u)
    if len(np.unique(train_data[D, -1])) == 1:  # 若样本为同一类别
        node.set_a_name(str(train_data[D[0], -1]))  # 标记为叶节点,类别为样本中的类别
    elif len(A) == 0 or len(temp_arr) == len(A):  # 若A为空集,或D中的样本所有属性取值相同
        node.set_a_name(type_name)  # 将节点标记为叶节点,类别为D中样本最多的类
    else:
        a_hit = get_optional(D, A, train_data)  # 获取最优划分属性a^
        node.set_a_name(a_hit)  # 设置节点的名字为最优属性名
        A_a_hit = np.unique(train_data[0:, a_hit])
        # 遍历最优属性的所有取值
        for ai in A_a_hit:
            # 为每个属性值创建分支结点,对D进行划分,划分依据是D中样本属性的值相同的样本为一类
            D_ai = []
            for ind in D:
                if train_data[ind, a_hit] == ai:
                    D_ai.append(ind)
            D_ai = np.array(D_ai)
            # 创建孩子结点
            child_node = treenode()
            if len(D_ai) == 0:  # 如果划分的集合为空集
                child_node.set_a_name(type_name)  # 标记为叶结点,类别为D中样本最多的类
                child_node.set_par(ai)  # 设置父属性的值
            else:
                # 递归创建决策树结点
                child_node = create_tree(D_ai, np.setdiff1d(A, a_hit), train_data, ai)
            node.add_cld(child_node)  # 将孩子节点连接到父节点
    return node  # 返回根节点


if __name__ == '__main__':
    load_data = np.loadtxt('xigua2.0.txt', dtype=str, delimiter=',', encoding='utf8')
    root = create_tree(np.arange(0, len(load_data)), np.arange(0, len(load_data[0]) - 1), load_data, None)
    A_name = np.array(['色泽', '根蒂', '敲声', '纹理', '脐部', '触感'])
    bfs_arr = []
    bfs_arr.append(root)
    while len(bfs_arr) > 0:
        for i in range(len(bfs_arr)):
            node = bfs_arr.pop(0)
            for a in node.cld:
                bfs_arr.append(a)

            if type(node.name) == type('a'):
                if node.name == '是':
                    print('好瓜  ' + str(node.par))
                else:
                    print('坏瓜  ' + str(node.par))
            else:
                print(str(A_name[node.name]) + '  ' + str(node.par))
        print()

西瓜数据集2.0

青绿,蜷缩,浊响,清晰,凹陷,硬滑,是
乌黑,蜷缩,沉闷,清晰,凹陷,硬滑,是
乌黑,蜷缩,浊响,清晰,凹陷,硬滑,是
青绿,蜷缩,沉闷,清晰,凹陷,硬滑,是
浅白,蜷缩,浊响,清晰,凹陷,硬滑,是
青绿,稍蜷,浊响,清晰,稍凹,软粘,是
乌黑,稍蜷,浊响,稍糊,稍凹,软粘,是
乌黑,稍蜷,浊响,清晰,稍凹,硬滑,是
乌黑,稍蜷,沉闷,稍糊,稍凹,硬滑,否
青绿,硬挺,清脆,清晰,平坦,软粘,否
浅白,硬挺,清脆,模糊,平坦,硬滑,否
浅白,蜷缩,浊响,模糊,平坦,软粘,否
青绿,稍蜷,浊响,稍糊,凹陷,硬滑,否
浅白,稍蜷,沉闷,稍糊,凹陷,硬滑,否
乌黑,稍蜷,浊响,清晰,稍凹,软粘,否
浅白,蜷缩,浊响,模糊,平坦,硬滑,否
青绿,蜷缩,沉闷,稍糊,稍凹,硬滑,否
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值