1. Decision Tree Visualization
- In the previous article we introduced the basic concepts of decision trees and information entropy, and how to grow a decision tree by choosing the optimal split. This article first shows how to render a decision tree diagram with the matplotlib library.

Introduction to matplotlib
- matplotlib is a fairly large library, commonly used in machine learning to visualize data intuitively, as in this decision tree diagram. This section shows how to use it to draw tree nodes and branches, and how to display text annotations.
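As a minimal standalone sketch (the node text and coordinates are illustrative), the core drawing primitive used below is `Axes.annotate`, which places text in a styled box and draws an arrow between two points given in axes-fraction coordinates:

```python
import matplotlib
matplotlib.use('Agg')  # non-interactive backend; remove this line to show a window
from matplotlib import pyplot as plt

fig = plt.figure(facecolor='white')
ax = plt.subplot(111, frameon=False)
# Draw the text 'root' at xytext inside a sawtooth box, with an arrow
# between xy and the box ('<-' reverses the arrow direction).
node = ax.annotate('root', xy=(0.1, 0.5), xycoords='axes fraction',
                   xytext=(0.5, 0.5), textcoords='axes fraction',
                   va='center', ha='center',
                   bbox={'boxstyle': 'sawtooth', 'fc': '0.8'},
                   arrowprops={'arrowstyle': '<-'})
```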
from matplotlib import pyplot as plt

class NodePlot(object):
    def __init__(self):
        self.ax = None
        self.decision_node = {'boxstyle': 'sawtooth', 'fc': '0.8'}  # style of decision (internal) nodes
        self.leaf_node = {'boxstyle': 'round4', 'fc': '0.8'}        # style of leaf nodes
        self.arrow_args = {'arrowstyle': '<-'}                      # arrow style
        self.fig = None

    def plot_node(self, node_text, center_pt, parent_pt, node_type):
        # Draw one node with an arrow connecting it to its parent
        self.ax.annotate(node_text, xy=parent_pt, xycoords='axes fraction',
                         xytext=center_pt, textcoords='axes fraction',
                         va='center', ha='center', bbox=node_type,
                         arrowprops=self.arrow_args)
class TreePlot(NodePlot):
    def create_tree(self, tree):
        self.fig = plt.figure(1, facecolor='white')  # create the drawing board
        self.fig.clf()
        self.ax = plt.subplot(111, frameon=False)
        self.max_width = float(self.get_leaf_nums(tree))   # width of the tree (number of leaves)
        self.max_depth = float(self.get_tree_depth(tree))  # height of the tree
        self.x_off = 0.0
        self.y_off = 1.0
        self.draw_tree(tree, None, '')  # start drawing from the initial point
        plt.show()
    def get_tree_depth(self, tree):
        max_depth = 0
        first_key = next(iter(tree))  # dict.keys() is not indexable in Python 3
        second_dict = tree[first_key]
        for item in second_dict:  # traverse the branches of this node
            if isinstance(second_dict[item], dict):
                depth = 1 + self.get_tree_depth(second_dict[item])
            else:
                depth = 1
            max_depth = max(max_depth, depth)  # keep the maximum depth over all branches
        return max_depth
    def get_leaf_nums(self, tree):
        max_leafs = 0
        if isinstance(tree, dict):
            for item in tree:
                max_leafs += self.get_leaf_nums(tree[item])
        else:
            max_leafs += 1  # each leaf node adds 1 to the width
        return max_leafs
    def plot_middle(self, cntr_pt, parent_pt, text):
        # Write the branch label halfway between the parent and child nodes
        text_x = (parent_pt[0] - cntr_pt[0]) / 2 + cntr_pt[0]
        text_y = (parent_pt[1] - cntr_pt[1]) / 2 + cntr_pt[1]
        self.ax.text(text_x, text_y, text)
    def draw_tree(self, tree, parent_node, node_text):
        width = float(self.get_leaf_nums(tree))
        first_key = next(iter(tree))
        cntr_pt = (self.x_off + (1 + width) / 2 / self.max_width, self.y_off)
        if parent_node is None:
            self.plot_node(first_key, cntr_pt, cntr_pt, self.decision_node)
        else:
            self.plot_middle(cntr_pt, parent_node, node_text)
            self.plot_node(first_key, cntr_pt, parent_node, self.decision_node)
        self.y_off -= 1 / self.max_depth  # move down one level
        second_item = tree[first_key]
        for item in second_item:
            if isinstance(second_item[item], dict):
                self.draw_tree(second_item[item], cntr_pt, str(item))
            else:
                self.x_off += 1 / self.max_width
                self.plot_node(second_item[item], (self.x_off, self.y_off), cntr_pt, self.leaf_node)
                self.plot_middle((self.x_off, self.y_off), cntr_pt, item)
        self.y_off += 1 / self.max_depth  # move back up after drawing this subtree
>>> data = [[1, 2, 'yes'], [2, 3, 'unknown'], [4, 15, 'no'], [12, 2, 'no'], [2, 3, 'yes'], [2, 4, 'unknown']]
>>> label = ['x', 'y', 'z']
The generated figure is shown below:
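The plotting code assumes the tree built in the previous article is a nested dict: the top-level key is the splitting feature, and each branch value is either a class label (a leaf) or another subtree. A minimal sketch of that format with a hand-written, hypothetical tree, using standalone versions of the leaf-count and depth logic:

```python
# Hypothetical tree in the nested-dict format the plotting code expects:
# top-level key = splitting feature, values = class label or subtree.
tree = {'x': {1: 'yes', 2: {'y': {3: 'unknown', 4: 'unknown'}}, 4: 'no', 12: 'no'}}

def leaf_count(tree):
    """Number of leaves, i.e. the horizontal width used by the plot."""
    if not isinstance(tree, dict):
        return 1
    return sum(leaf_count(sub) for sub in tree.values())

def tree_depth(tree):
    """Number of decision levels, i.e. the vertical height used by the plot."""
    first_key = next(iter(tree))
    depths = [1 + tree_depth(sub) if isinstance(sub, dict) else 1
              for sub in tree[first_key].values()]
    return max(depths)

print(leaf_count(tree))  # 5
print(tree_depth(tree))  # 2
```

With these two numbers the drawing code divides the unit square into a grid, placing each leaf one column apart and each level one row apart.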
2. Decision Tree Classification
- The general workflow for classifying with a decision tree: starting from the root node, look up the test vector's value for the node's splitting feature, follow the matching branch, and repeat until a leaf node (a class label) is reached.
- The classification code is as follows:
def classify(input_tree, labels, test_vec):
    first_str = next(iter(input_tree))  # splitting feature at this node
    second_dict = input_tree[first_str]
    label_index = labels.index(first_str)  # position of that feature in the test vector
    class_label = ''
    for key in second_dict:
        if key == test_vec[label_index]:
            if isinstance(second_dict[key], dict):
                # The branch leads to a subtree: keep descending
                class_label = classify(second_dict[key], labels, test_vec)
            else:
                class_label = second_dict[key]  # the branch leads to a leaf: class found
    return class_label
>>> data = [[1, 2, 'yes'], [2, 3, 'unknown'], [4, 15, 'no'], [12, 2, 'no'], [2, 3, 'yes'], [2, 4, 'unknown']]
>>> label = ['x', 'y', 'z']
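A self-contained usage sketch of `classify` with a hand-written, hypothetical tree (not the one actually grown from `data` above):

```python
def classify(input_tree, labels, test_vec):
    # Descend the tree by matching the test vector's feature values
    first_str = next(iter(input_tree))
    second_dict = input_tree[first_str]
    label_index = labels.index(first_str)
    class_label = ''
    for key in second_dict:
        if key == test_vec[label_index]:
            if isinstance(second_dict[key], dict):
                class_label = classify(second_dict[key], labels, test_vec)
            else:
                class_label = second_dict[key]
    return class_label

# Hypothetical tree: split on 'x' first, then on 'y'
tree = {'x': {1: 'yes', 2: {'y': {3: 'yes', 4: 'unknown'}}}}
labels = ['x', 'y']
print(classify(tree, labels, [2, 4]))  # unknown
print(classify(tree, labels, [1, 3]))  # yes
```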
3. Decision Tree Storage
- When the dataset is large, regenerating the decision tree each time takes too long. If the training data rarely changes, the tree can be stored with the pickle module.
- The storage code is as follows:
import pickle

def save_tree(tree, filename):
    # pickle.dump requires an open file object, not a path string
    try:
        with open(filename, 'wb') as f:
            pickle.dump(tree, f)
        return True
    except Exception:
        return None

def get_tree(filename):
    try:
        with open(filename, 'rb') as f:
            return pickle.load(f)
    except Exception:
        return None
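A round-trip sketch of the same save/load idea, using a temporary file instead of a hard-coded path (the sample tree is hypothetical):

```python
import os
import pickle
import tempfile

tree = {'x': {1: 'yes', 2: 'no'}}  # hypothetical trained tree

# Serialize to disk, then deserialize and verify the round trip
path = os.path.join(tempfile.mkdtemp(), 'tree.pkl')
with open(path, 'wb') as f:
    pickle.dump(tree, f)
with open(path, 'rb') as f:
    restored = pickle.load(f)
print(restored == tree)  # True
```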
4. Summary
- This article covered how to visualize a decision tree with matplotlib, how to classify data with a decision tree, and how to store a decision tree.
- The next article will introduce the Bayes algorithm.
- Many expert bloggers are fond of the decision tree algorithm, but it has pitfalls: when the data is very noisy, the tree keeps growing and becomes overly redundant. A later article will cover how to prune (truncate) a decision tree.
5. References
- Machine Learning in Action (《机器学习实战》)