机器学习值决策树算法(下)-ID3实现

1.决策树展示

决策树展示
上一篇文章,我们介绍了决策树的基本概念,信息基本概念以及如何通过选择最优分类生成决策树,本篇文章首先介绍如何根据决策树通过matplotlib库来实现决策树图的展现。
matplotlib介绍
matplotlib是一个相对来说比较庞大的库,常用于机器学习中,用于直观的展示数据,例如本次的决策树展示,本章中介绍如何用该库创建树枝节点,以及如何展示文本信息。
from matplotlib import pyplot as plt


class NodePlot(object):
    def __init__(self):
        self.ax = None
        self.decision_node = {'boxstyle': 'sawtooth', 'fc': '0.8'} # 决策节点类型
        self.leaf_node = {'boxstyle': 'round4', 'fc': '0.8'} # 叶子节点类型
        self.arrow_args = {'arrowstyle': '<-'} # 箭头类型
        self.fig = None

    def plot_node(self, node_text, center_pt, parent_pt, node_type):
        self.ax.annotate(node_text, xy=parent_pt, xycoords='axes fraction',
                         xytext=center_pt, textcoords='axes fraction',
                         va='center', ha='center', bbox=node_type, arrowprops=self.arrow_args)


class TreePlot(NodePlot):
    def create_tree(self, tree):
        self.fig = plt.figure(1, facecolor='white') # 生成画板
        self.fig.clf() 
        self.ax = plt.subplot(111, frameon=False) 
        self.max_width = float(self.get_leaf_nums(tree)) # 计算树的宽度
        self.max_depth = float(self.get_tree_depth(tree)) # 计算树的高度
        self.x_off = 0.0
        self.y_off = 1.0
        self.draw_tree(tree, None, '') # 给定初始点,开始绘画决策树
        plt.show()

    def get_tree_depth(self, tree):
        max_depth = 0
        first_key = tree.keys()[0]
        second_dict = tree[first_key]
        for item in second_dict: # 遍历决策树
            if isinstance(second_dict[item], dict):
                depth = 1 + self.get_tree_depth(second_dict[item])
            else:
                depth = 1
            max_depth = max(max_depth, depth) # 取各个分支节点下的最大高度
        return max_depth

    def get_leaf_nums(self, tree):
        max_leafs = 0
        if isinstance(tree, dict):
            for item in tree:
                max_leafs += self.get_leaf_nums(tree[item])
        else:
            max_leafs += 1 # 每个叶子节点,宽度加1
        return max_leafs

    def plot_middle(self, cntr_pt, parent_pt, text):
        text_x = (parent_pt[0] - cntr_pt[0]) / 2 + cntr_pt[0]
        text_y = (parent_pt[1] - cntr_pt[1]) / 2 + cntr_pt[1]
        self.ax.text(text_x, text_y, text)

    def draw_tree(self, tree, parent_node, node_text):
        width = float(self.get_leaf_nums(tree))
        first_key = tree.keys()[0]
        cntr_pt = (self.x_off + (1 + width) / 2 / self.max_width, self.y_off)

        if parent_node is None:
            self.plot_node(first_key, cntr_pt, cntr_pt, self.decision_node)
        else:
            self.plot_middle(cntr_pt, parent_node, node_text)
            self.plot_node(first_key, cntr_pt, parent_node, self.decision_node)

        self.y_off -= 1 / self.max_depth

        second_item = tree[first_key]

        for item in second_item.keys():
            if isinstance(second_item[item], dict):
                self.draw_tree(second_item[item], cntr_pt, str(item))
            else:
                self.x_off += 1 / self.max_width
                self.plot_node(second_item[item], (self.x_off, self.y_off), cntr_pt, self.leaf_node)
                self.plot_middle((self.x_off, self.y_off), cntr_pt, item)
        self.y_off += 1 / self.max_depth
>>> data = [[1, 2, 'yes'], [2, 3, 'unknown'], [4, 15, 'no'], [12, 2, 'no'], [2, 3, 'yes'], [2, 4, 'unknown']]
>>> label = ['x', 'y', 'z']

生成结果如下:
图1-决策树生成图

2.决策树分类

决策树分类
根据决策树分类的一般流程如下:
Created with Raphaël 2.1.0 创建数据集 根据数据集创建决策树 将数据带入决策树 数据分类

: 分类代码如下:

def classify(input_tree, labels, test_vec):
    first_str = input_tree.keys()[0]

    second_dict = input_tree[first_str]
    label_index = labels.index(first_str)
    class_label = ''

    for key in second_dict.keys():
        if key == test_vec[label_index]:
            if isinstance(second_dict[key], dict):
                class_label = classify(second_dict[key], labels, test_vec)
            else:
                class_label = second_dict[key]
    return class_label
>>> data = [[1, 2, 'yes'], [2, 3, 'unknown'], [4, 15, 'no'], [12, 2, 'no'], [2, 3, 'yes'], [2, 4, 'unknown']]
>>> label = ['x', 'y', 'z']

3.决策树存储

决策树存储
当数据集很大时,每次生成决策树耗费的时间过长,而且训练数据集不常变动时,可以通过pickle模块存储决策树。:
存储代码如下:
import pickle


def save_tree(tree):
    try:
        pickle.dump(tree, 'C:\\tree')
        return True
    except Exception, e:
        return None


def get_tree(f):
    try:
        return pickle.load(f)
    except Exception, e:
        return None

4.总结

  • 本文介绍了如何通过matplotlib展示决策树,如何使用决策树进行数据分类,如何存储决策树。
  • 下一篇文章会介绍贝叶斯算法。
  • 看过很多大牛的博客,都很钟爱决策树算法,但决策树算法也有坑,例如数据过于杂乱时,会导致树一直生成,树过于冗余,后面介绍如何截断决策树。

5.参考

  • [机器学习实战]
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值