Converting a Decision Tree Generated by sklearn into a Rule Tree

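Calling DecisionTreeClassifier from sklearn makes it very easy to train a decision tree. The person using it, however, sees neither the structure of the resulting tree nor how the model arrives at its decisions. This article prints a fitted decision tree as a set of if/else rules and also renders it graphically, so the model is easier to understand.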

The implementation is as follows:

# -*-  coding: utf-8 -*-
from io import StringIO  # sklearn.externals.six is gone from recent scikit-learn releases
import pydotplus
from sklearn import tree
from sklearn.tree import _tree
from sklearn.model_selection import train_test_split
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

def draw_tree(model,name):
    dot_data = StringIO()
    tree.export_graphviz(model, out_file = dot_data)
    graph = pydotplus.graph_from_dot_data(dot_data.getvalue())
    graph.write_pdf(name + ".pdf")

def tree_to_code(model, feature_names):
    """Print a fitted decision tree as nested if/else rules."""
    tree_ = model.tree_
    # Map each node's feature index to a name; leaf nodes carry TREE_UNDEFINED.
    feature_name = [
        feature_names[i] if i != _tree.TREE_UNDEFINED else "undefined!"
        for i in tree_.feature
    ]
    print("def tree({}):".format(", ".join(feature_names)))

    def recurse(node, depth):
        indent = "  " * depth
        if tree_.feature[node] != _tree.TREE_UNDEFINED:
            # Internal node: print the split, then descend into both children.
            name = feature_name[node]
            threshold = tree_.threshold[node]
            print("{}if {} <= {}:".format(indent, name, threshold))
            recurse(tree_.children_left[node], depth + 1)
            print("{}else:  # if {} > {}".format(indent, name, threshold))
            recurse(tree_.children_right[node], depth + 1)
        else:
            # Leaf node: print the per-class values stored at this node.
            print("{}return {}".format(indent, tree_.value[node]))

    recurse(0, 1)
    
if __name__ == '__main__':
    iris = load_iris()
    X = iris.data
    y = iris.target
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.3, random_state=0)
    
    estimator = DecisionTreeClassifier(max_depth=2)
    estimator.fit(X_train, y_train)
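    # These names are arbitrary placeholders for the four iris features
    # (sepal length, sepal width, petal length, petal width), in that order.
    tree_to_code(estimator, ["length", "width", "height", "fps"])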
    draw_tree(estimator,'learn')

Output of the program:

def tree(length, width, height, fps):
  if fps <= 0.75:
    return [[34.  0.  0.]]
  else:  # if fps > 0.75
    if height <= 4.950000047683716:
      return [[ 0. 31.  3.]]
    else:  # if height > 4.950000047683716
      return [[ 0.  1. 36.]]
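
For comparison, scikit-learn (version 0.21 and later) also ships a built-in helper, export_text, that prints a similar rule listing without any hand-written recursion. The following is only a minimal sketch: it assumes the same iris split and max_depth=2 as above, the random_state=0 on the classifier is my own addition for reproducibility, and the feature names come from iris.feature_names rather than the placeholder names used earlier.

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier, export_text

iris = load_iris()
X_train, X_test, y_train, y_test = train_test_split(
    iris.data, iris.target, test_size=0.3, random_state=0)

# random_state on the classifier is an assumption added for reproducibility.
estimator = DecisionTreeClassifier(max_depth=2, random_state=0)
estimator.fit(X_train, y_train)

# export_text returns the rules as one multi-line string.
rules = export_text(estimator, feature_names=list(iris.feature_names))
print(rules)
```

Each leaf is printed as a predicted class index instead of the raw value array, which can be easier to read than the output of tree_to_code.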

Visualization result (written to learn.pdf):
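
By default the rendered nodes refer to features only by index. As a hedged sketch of how the picture could be made more readable, export_graphviz accepts feature_names and class_names, and with out_file=None it returns the DOT source as a string, so this step needs neither pydotplus nor the Graphviz binaries. The file name learn_labeled.dot and the choice to fit on the full iris data are assumptions made for this illustration only.

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, export_graphviz

iris = load_iris()
# Fitted on the full dataset here purely to keep the example short.
model = DecisionTreeClassifier(max_depth=2, random_state=0)
model.fit(iris.data, iris.target)

# out_file=None makes export_graphviz return the DOT source as a string.
dot_source = export_graphviz(
    model,
    out_file=None,
    feature_names=list(iris.feature_names),
    class_names=list(iris.target_names),
    filled=True,    # color each node by its majority class
    rounded=True,   # draw boxes with rounded corners
)

# Hypothetical output file name; any path works.
with open("learn_labeled.dot", "w", encoding="utf-8") as f:
    f.write(dot_source)
```

With Graphviz installed, the file can then be rendered from the command line, for example with dot -Tpdf learn_labeled.dot -o learn_labeled.pdf.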

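If you copy and paste the code unchanged and it fails at runtime, check the following possible causes:

1. A package the code depends on is not installed. Install whatever is missing; from the command line this is simply pip install ** (where ** stands for the package name).

2. Visualizing the decision tree requires Graphviz to be installed. Look up the installation instructions for your platform.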
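
The two checks above can be sketched as a small script that uses only the standard library. The function name missing_dependencies and its default arguments are hypothetical, chosen for this illustration; it assumes the Graphviz executable is called dot and is expected on the PATH.

```python
import importlib.util
import shutil


def missing_dependencies(modules=("sklearn", "pydotplus"), executable="dot"):
    """Return a list describing Python modules and executables not found."""
    # find_spec returns None when a top-level module cannot be imported.
    missing = [m for m in modules if importlib.util.find_spec(m) is None]
    # shutil.which returns None when the executable is not on the PATH.
    if shutil.which(executable) is None:
        missing.append(executable + " (Graphviz executable)")
    return missing


problems = missing_dependencies()
if problems:
    print("Missing:", ", ".join(problems))
else:
    print("All dependencies found.")
```

Note that the module sklearn is installed with pip install scikit-learn, while Graphviz itself is a separate program and is not provided by the pydotplus package.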

If you have any questions, feel free to leave a comment.
