ID3决策树的Python实现以及可视化

算法介绍

ID3决策树是比较经典的决策树,在周志华的机器学习中,生成决策树的算法为:
在这里插入图片描述
算法的关键是如何选择最优划分属性,在ID3决策树中,用信息增益来指导决策树选择最优划分属性
首先定义信息熵为:
在这里插入图片描述
再定义信息增益为:
在这里插入图片描述
一般而言,信息增益越大,意味着使用属性a进行划分所获得的纯度提升越大,因此我们选择最大信息增益的属性作为最优划分属性。

Python实现思路

树的数据表示

既然要实现一棵树,首先要做的就是定义节点的数据结构,在C中,节点一般以结构体的形式存储,所以我们在Python中可以参考这一思路定义一个节点类:

class Node():
    """
    ID3决策树的节点
    parent -- 父节点
    sons -- 子节点集合,即在该节点最优划分属性下每个属性值的分支
    attrs -- 该节点下的最优划分属性
    parent_attrs_value -- 表示该节点是父节点哪一个属性的分支
    label -- 如果这个节点是叶子节点,则存放标签
    """
    def __init__(self, parent=None):
        self.parent = parent            
        self.sons = []                  
        self.attr = None                
        self.parent_attrs_value = None  
        self.label = None               

但在实际操作中,使用这一方法给代码的调试增加了难度,同时不利于后面用Graphviz包实现决策树的可视化,因此本文考虑使用另一种数据结构表示树,就是Python中的字典,我们先来看看对于西瓜书中给出的一颗决策树,用字典是如何表示的:
西瓜书中的一颗决策树:
在这里插入图片描述
对应的Python字典表示:

tree = {'纹理':
            {'清晰':
                {'根蒂':
                    {'蜷缩':
                        {'label':'是'}, 
                    '稍蜷':
                        {'色泽':
                            {'青绿':
                                {'label':'是'}, 
                            '乌黑':
                                {'触感':
                                    {'硬滑':
                                        {'label':'是'}, 
                                    '软粘':
                                        {'label':'否'}}},
                            '浅白':
                                {'label':'是'}}},
                    '硬挺':
                        {'label':'否'}}}, 
            '稍糊':
                {'触感':
                    {'硬滑':
                        {'label':'否'}, 
                    '软粘':
                        {'label':'是'}}}, 
            '模糊':
                {'label':'否'}}}

如何可视化决策树

在本文中,使用Graphviz包进行决策树的可视化,这里是官网文档
只需使用几条简单的代码便可将决策树的节点绘制出来:

g = graphviz.Digraph(name=,filename=, format='png')
g.node(name=, label=, fontname="Microsoft YaHei", shape=)
g.edge(tail_name, head_name, label=, fontname="Microsoft YaHei")
g.view()

要注意,如果决策树的信息是中文的,要在fontname参数中指定中文字体,不然会出现乱码

Python代码

DecesionTree.py

import numpy as np
import scipy.io as sio
from collections import Counter
from graphviz import Digraph

class DecisionTree():
    """
    一个构建ID3决策树的类
    attrs -- 存放属性的字典, 字典中,键为属性名,值为属性的取值,最后一个属性为标签属性
    X -- 训练数据
    y -- 标签
    attr_idx -- 属性列索引
    tree -- 生成的决策树,用字典形式存放
    node_name -- 用于对决策树的可视化,在graphviz中对节点的命名
    """
    def __init__(self):
        self.attrs = None
        self.X = None
        self.y = None
        self.attr_idx = None
        self.tree = {}
        self.node_name = "0"


    def get_attrs(self, data):
        """
        对数据集进行处理,得到属性与对应的属性取值
        args:
        data -- 输入的数据矩阵, shape=(samples+1, features), dtype='<U?', 其中,第一行为属性,最后一列为标签
        returns:
        attrs -- 存放属性的字典, 字典中,键为属性名,值为属性的取值 
        """
        attrs = {}
        for i in range(data.shape[1]):
            attrs_values = sorted(set(data[1:, i]))
            attrs[data[0][i]] = attrs_values

        self.attrs = attrs
        return attrs


    def generate_tree(self, data):
        """
        生成决策树
        args:
        data -- 输入的数据矩阵, shape=(samples+1, features+label), dtype='<U?', 其中,第一行为属性,最后一列为标签
        """
        self.X = data[1:, :-1]
        self.y = data[1:, -1]

        # 先创建一个不含label属性的纯变量属性字典
        pure_attrs = self.attrs.copy()
        del(pure_attrs['label'])
        # 构造一个只含属性名的列表
        attr_names = [attr_name for attr_name in pure_attrs.keys()]
        # 将属性名编号,方便查找其在数据中对应的列
        attr_idx = {}
        for num, attr in enumerate(attr_names):
            attr_idx[attr] = num
        self.attr_idx = attr_idx

        # 生成根节点
        self.tree['root_node'] = {}
        self._generate_tree(self.X, self.y, self.tree['root_node'], pure_attrs, attr_idx)
        self.tree = self.tree['root_node']
        

    def _generate_tree(self, X, y, node, attrs, attr_idx):
        """
        递归生成决策树
        args:
        X -- 输入的数据矩阵, shape=(samples, features), dtype='<U?'
        y -- 标签, shape=(samples, )
        parent_node -- 父节点,此次递归函数是父节点的某一个属性值的递归
        attrs -- 属性字典, 即从父节点分支到现在的节点时,还没有被划分的属性
        attr_idx -- 属性在数据中列索引
        """

        #--------- 如果训练集中样本全属于同一类别 ---------#
        if len(set(y.tolist())) == 1:
            node['label'] = y[0]
            return

        #-------- 如果属性集为空集或者训练集中样本在属性集上取值相同 ---------#
        # 判断训练集样本在属性集中取值是否相同
        same = True
        for i in range(X.shape[1]):
            if len(set(X[:, i].tolist())) > 1:
                same = False
        
        if not attrs or same:
            y_counter = Counter(y)
            most_y = y_counter.most_common()[0][0]
            node['label'] = most_y
            return

        #--------- 选择最优属性生成分支 ---------#
        # 选出最优划分属性
        optimal_attr = self.choose_optimal_attr(X, y, attrs, attr_idx)
        node[optimal_attr] = {}
        node = node[optimal_attr]
        # 对于最优划分属性下每个属性值
        for attr_value in attrs[optimal_attr]:
            # 生成分支
            node[attr_value] = {}
            # 令Dv表示X中在optimal_attr上取值为attr_value的样本子集
            Dv = X.copy()
            attr_value_idx = Dv[:, attr_idx[optimal_attr]] == attr_value
            Dv = Dv[attr_value_idx, :]
            y_Dv = y[attr_value_idx]
            Dv = np.delete(Dv, attr_idx[optimal_attr], 1)
            # 如果Dv为空
            if Dv.size == 0:
                # 将分支节点标记为叶节点,其类别标记为X中样本最多的类,即统计y
                y_counter = Counter(y)
                most_y = y_counter.most_common()[0][0]
                node[attr_value]['label'] = most_y
            else:
                # 更新属性字典
                new_attrs = attrs.copy()
                del(new_attrs[optimal_attr])
                # 更新属性列索引
                new_attr_names = [new_attr_name for new_attr_name in new_attrs.keys()]
                new_attr_idx = {}
                for num, attr in enumerate(new_attr_names):
                    new_attr_idx[attr] = num
                self._generate_tree(Dv, y_Dv, node[attr_value], new_attrs, new_attr_idx)


    def compute_Ent(self, y):
        """
        计算给出属性名列表所对应的所有样本的信息熵
        args:
        y -- 标签数组, shape=(samples, )
        return:
        Ent -- 样本的信息熵
        """
        Ent = 0
        m = np.size(y)
        for label in self.attrs['label']:
            pk = np.sum(y == label)
            pk = pk / m
            log2pk = np.log2(pk + 1e-8) # 防止算得0,导致返回nan
            Ent -= pk * log2pk
        return Ent


    def choose_optimal_attr(self, X, y, attrs, attr_idx):
        """
        选择最优划分属性 划分标准:属性的信息增益
        args: 
        X -- 输入的数据矩阵, shape=(samples, features), dtype='<U?'
        y -- 标签, shape=(samples, )
        attrs -- 属性字典
        attr_idx -- 属性在数据中列索引
        returns:
        max_gain_attr -- 最大的信息增益对应的属性
        """
        # 计算当前所含属性对应所有样本的信息熵
        Ent = self.compute_Ent(y)
        m = np.size(y)
        # 记录当前最大的信息增益以及对应的属性
        max_gain = 0
        max_gain_attr = None
        
        # 计算每一个属性的信息增益
        for attr, idx in attr_idx.items():
            x = X[:, idx]
            gain = Ent
            # 计算一个属性中每个属性值的信息熵
            for attr_value in attrs[attr]:
                _y = y[x==attr_value]
                if _y.size != 0:
                    ent = self.compute_Ent(_y)
                else:
                    ent = 0
                gain -= np.size(_y) / m * ent
            if gain > max_gain:
                max_gain = gain
                max_gain_attr = attr
                
        return max_gain_attr


    def predict(self, predict_x):
        """
        预测样本结果
        args:
        predict_x -- 预测样本数据矩阵 shape=(samples, features)
        returns:
        predict_y -- 样本的预测结果 shape=(samples, )
        """
        s = predict_x.shape[0]
        predict_y = []
        for i in range(s):
            node = self.tree
            while(1):
                if 'label' in node.keys():
                    predict_y.append(node['label'])
                    break
                elif list(node.keys())[0] in self.attrs.keys():
                    attr = list(node.keys())[0]
                    idx = self.attr_idx[attr]
                    node = node[attr]
                else:
                    node = node[predict_x[i, idx]]
                    # 如果测试样本中的属性值在没有在训练集中出现,用下面的代码处理这种情况
                    # attr_value = predict_x[i, idx]
                    # if attr_value in node.keys():
                    #    node = node[attr_value]
                    # else:
                    #    node = node[list(node.keys())[0]]
        return predict_y


    def tree_traversal(self, g, parent_node, parent_node_name, parent_attr, parent_attr_value):
        """
        对树进行遍历,生成可视化的节点
        g -- 要绘制的有向图
        parent_node -- 父节点
        parent_node_name -- 父节点在有向图中的代号
        parent_attr -- 父节点的属性
        parent_attr_value -- 父节点到该节点的属性值
        """
        if (parent_attr and parent_attr_value) is None:
            if 'label' in parent_node.keys():
                g.node(name=self.node_name, label=parent_node['label'], fontname="Microsoft YaHei")
                return
            else:
                attr = list(parent_node.keys())[0]
                node = parent_node[attr]
                parent_node_name = "0"
                for attr_value in node.keys():
                    self.tree_traversal(g, node[attr_value], parent_node_name, attr, attr_value)
        else:
            if 'label' in parent_node.keys():
                g.node(name=parent_node_name, label=parent_attr, fontname="Microsoft YaHei", shape='box')
                self.node_name = str(int(self.node_name) + 1)
                g.node(name=self.node_name, label=parent_node['label'], fontname="Microsoft YaHei")
                g.edge(parent_node_name, self.node_name, label=parent_attr_value, fontname="Microsoft YaHei")
            else:
                attr = list(parent_node.keys())[0]
                g.node(name=parent_node_name, label=parent_attr, fontname="Microsoft YaHei", shape='box')
                self.node_name = str(int(self.node_name) + 1)
                g.node(name=self.node_name, label=attr, fontname="Microsoft YaHei", shape='box')
                g.edge(parent_node_name, self.node_name, label=parent_attr_value, fontname="Microsoft YaHei")
                node = parent_node[attr]
                parent_node_name = self.node_name
                for attr_value in node.keys():
                    self.tree_traversal(g, node[attr_value], parent_node_name, attr, attr_value)


    def tree_visualize(self, file_name=None):
        """
        将决策树可视化
        args:
        file_name -- 若给出该参数,则将决策树保存为file_name的图片
        """
        if file_name:
            g = Digraph("Decision Tree", filename=file_name, format='png')
        else:
            g = Digraph("Decision Tree")
        self.tree_traversal(g, self.tree, None, None, None)
        g.view()


if __name__ == "__main__":
    pass

主函数,以西瓜树的西瓜数据集为例生成决策树,原数据集是Matlab的cell数组,并以mat文件存放,因此需要预处理一下:

import numpy as np
import scipy.io as sio
from DecisionTree import DecisionTree

def preprocess():
    raw_data = sio.loadmat('watermelon.mat')
    raw_data = raw_data['watermelon']
    data = np.zeros(raw_data.shape, dtype='<U20')

    for i in range(data.shape[0]):
        for j in range(data.shape[1]):
            data[i, j] = raw_data[i, j][0]

    data[0, -1] = 'label'
    return data

def main_1():
    """
    完整决策树
    """
    data = preprocess()
    DTree = DecisionTree()
    attrs = DTree.get_attrs(data)
    DTree.generate_tree(data)
    DTree.tree_visualize('watermelob_tree')

def main_2():
    """
    留出两个样本作为测试集
    """
    data = preprocess()
    train_idx = np.delete(np.arange(0, 18), [8, 17])
    test_idx = [8, 17]
    train_data = data[train_idx, :]
    test_data = data[test_idx, :]
    test_X = test_data[:, :-1]
    test_y = test_data[:, -1]

    DTree = DecisionTree()
    DTree.get_attrs(train_data)
    DTree.generate_tree(train_data)
    predict_y = DTree.predict(test_X)
    print(predict_y)
    DTree.tree_visualize('watermelon_tree_2')

main_1()

最终生成的决策树图片为:
在这里插入图片描述
到这里我们就成功地用Python实现了ID3决策树!

  • 6
    点赞
  • 33
    收藏
    觉得还不错? 一键收藏
  • 5
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 5
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值