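With scikit-learn, fitting a decision tree takes little more than a call to DecisionTreeClassifier. The fitted model, however, shows neither what the tree looks like nor how it arrives at a decision. This post prints a fitted decision tree as a set of if/else rules and also renders it as a diagram, so that the model is easier to understand.
The code is as follows: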
# -*- coding: utf-8 -*-
# sklearn.externals.six has been removed from scikit-learn; the standard library provides StringIO.
from io import StringIO

import pydotplus
from sklearn import tree
from sklearn.tree import _tree
from sklearn.model_selection import train_test_split
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier


def draw_tree(model, name):
    """Render the fitted tree to <name>.pdf (requires Graphviz)."""
    dot_data = StringIO()
    tree.export_graphviz(model, out_file=dot_data)
    graph = pydotplus.graph_from_dot_data(dot_data.getvalue())
    graph.write_pdf(name + ".pdf")


def tree_to_code(model, feature_names):
    """Print the fitted tree as nested if/else rules."""
    tree_ = model.tree_
    # Name of the feature each node splits on; leaf nodes have no feature.
    feature_name = [
        feature_names[i] if i != _tree.TREE_UNDEFINED else "undefined!"
        for i in tree_.feature
    ]
    # print('feature_name:', feature_name)
    print("def tree({}):".format(", ".join(feature_names)))

    def recurse(node, depth):
        indent = " " * depth
        # print('tree_.feature:', tree_.feature)
        if tree_.feature[node] != _tree.TREE_UNDEFINED:
            # Internal node: print the test, then descend into both children.
            # print('tree_.feature[node]:', tree_.feature[node])
            name = feature_name[node]
            threshold = tree_.threshold[node]
            print("{}if {} <= {}:".format(indent, name, threshold))
            recurse(tree_.children_left[node], depth + 1)
            print("{}else: # if {} > {}".format(indent, name, threshold))
            recurse(tree_.children_right[node], depth + 1)
        else:
            # Leaf node: print the per-class values stored at this leaf.
            print("{}return {}".format(indent, tree_.value[node]))

    recurse(0, 1)


if __name__ == '__main__':
    iris = load_iris()
    X = iris.data
    y = iris.target
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=0)
    estimator = DecisionTreeClassifier(max_depth=2)
    estimator.fit(X_train, y_train)
    # Placeholder names for the four iris columns
    # (sepal length, sepal width, petal length, petal width).
    tree_to_code(estimator, ["length", "width", "height", "fps"])
    draw_tree(estimator, 'learn')
Program output:
def tree(length, width, height, fps):
 if fps <= 0.75:
  return [[34.  0.  0.]]
 else: # if fps > 0.75
  if height <= 4.950000047683716:
   return [[ 0. 31.  3.]]
  else: # if height > 4.950000047683716
   return [[ 0.  1. 36.]]
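The arrays returned at the leaves can be hard to read at first. The sketch below shows one way to interpret them, assuming each array holds the number of training samples per class that reached the leaf (the three arrays sum to 105, which is 70% of the 150 iris samples) and that the class order follows load_iris().target_names. The helper describe_leaf and the rule strings are hypothetical names introduced only for this illustration. Some scikit-learn versions may store class fractions instead of counts; the largest entry still identifies the predicted class.

```python
# Leaf arrays copied from the program output above.
# Assumption: one training-sample count per class, in target_names order.
CLASS_NAMES = ["setosa", "versicolor", "virginica"]


def describe_leaf(counts, class_names=CLASS_NAMES):
    """Return (predicted class name, share of the leaf's samples in that class)."""
    total = sum(counts)
    best = max(range(len(counts)), key=lambda i: counts[i])
    return class_names[best], counts[best] / total


leaves = {
    "fps <= 0.75": [34, 0, 0],
    "fps > 0.75 and height <= 4.95": [0, 31, 3],
    "fps > 0.75 and height > 4.95": [0, 1, 36],
}

for rule, counts in leaves.items():
    label, share = describe_leaf(counts)
    print(f"{rule}: predict {label} ({share:.0%} of {sum(counts)} samples)")
```

Read this way, the first leaf is pure (only setosa samples), while the other two each contain a few samples of the neighbouring class.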
Visualization result (draw_tree writes it to learn.pdf):
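Because draw_tree calls export_graphviz without feature or class names, the nodes in the PDF carry generic index-based labels. The following is a minimal sketch of one possible refinement, not part of the original script: it passes the iris feature and class names to export_graphviz and asks for the DOT source as a string via out_file=None. It assumes the same train/test split as the script above; the variable name dot_source is chosen here for illustration.

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier, export_graphviz

iris = load_iris()
X_train, X_test, y_train, y_test = train_test_split(
    iris.data, iris.target, test_size=0.3, random_state=0)
estimator = DecisionTreeClassifier(max_depth=2).fit(X_train, y_train)

# With out_file=None, export_graphviz returns the DOT source as a string.
dot_source = export_graphviz(
    estimator,
    out_file=None,
    feature_names=iris.feature_names,      # real column names instead of indices
    class_names=list(iris.target_names),   # label each node with its majority class
    filled=True,                           # colour nodes by majority class
    rounded=True,
)
print(dot_source)
```

The resulting string can be handed to pydotplus.graph_from_dot_data in the same way draw_tree does, so rendering still requires Graphviz.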
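If you copy the code verbatim and it fails with an error, check the following possible causes:
1. A package the code depends on is not installed. Install whatever is missing; from the command line this is simply pip install <package name>.
2. Visualizing the decision tree requires Graphviz to be installed; look up the installation instructions for your platform.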
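Both checks in the list above can be automated before running the script. The sketch below uses only the standard library; the function name missing_dependencies and the module list are assumptions made for this example. It also assumes that Graphviz exposes its renderer as an executable named dot on the PATH, which is the usual setup.

```python
import importlib.util
import shutil

# Import names used by the script; pip package names can differ
# (the "sklearn" module, for example, is installed as "scikit-learn").
REQUIRED_MODULES = ["sklearn", "pydotplus"]


def missing_dependencies(modules=REQUIRED_MODULES, executable="dot"):
    """Return a list of human-readable problems; empty if everything is found."""
    problems = []
    for name in modules:
        if importlib.util.find_spec(name) is None:
            problems.append(f"Python module '{name}' is not installed")
    # Rendering the PDF relies on the Graphviz command-line program.
    if shutil.which(executable) is None:
        problems.append(f"Graphviz executable '{executable}' is not on PATH")
    return problems


for problem in missing_dependencies():
    print(problem)
```

If nothing is printed, the imports and the Graphviz executable were found, and any remaining error most likely comes from somewhere else.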
If you have any questions, feel free to leave a comment.