将sklearn的决策树规则输出成SQL
主要使用 sklearn.tree._tree 读取决策树的信息。
1、输出成SQL的主要函数
import numpy as np
from sklearn.tree import DecisionTreeClassifier
from sklearn.tree import _tree
def tree_to_code(tree, feature_names):
    """Render a fitted decision tree as the lines of a nested SQL CASE expression.

    The tree is walked depth-first starting at node 0. Every split node
    contributes a ``CASE WHEN <feature> <= <threshold> THEN`` line, the
    lines of its left subtree, an ``ELSE -- if ...`` line and the lines of
    its right subtree. Every leaf contributes a ``'Class_<k>'`` line, where
    ``k`` is the position of the largest entry in the leaf's value row.
    No ``END`` keywords are emitted here; they are added afterwards by
    ``find_END`` / ``add_END``.

    Args:
        tree: Fitted estimator exposing ``tree_`` with the ``feature``,
            ``threshold``, ``children_left``, ``children_right`` and
            ``value`` arrays.
        feature_names: Sequence indexed by feature id; ``feature_names[i]``
            is the column name written into the SQL for feature ``i``.

    Returns:
        list[str]: One SQL fragment per visited node/branch, each prefixed
        with ``depth`` spaces (the root is at depth 1).
    """
    tree_ = tree.tree_
    sql_lines = []

    def to_sql(node, depth):
        """Append the SQL lines for the subtree rooted at ``node``."""
        indent = " " * depth
        left_child = tree_.children_left[node]
        right_child = tree_.children_right[node]

        # A leaf has the same sentinel in both child slots, whereas a split
        # node always has two distinct children. This avoids depending on
        # the private ``_tree.TREE_UNDEFINED`` constant.
        if left_child != right_child:
            name = feature_names[tree_.feature[node]]
            threshold = tree_.threshold[node]

            # Left branch: feature <= threshold.
            sql_lines.append(
                "{}CASE WHEN {} <= {} THEN".format(indent, name, threshold)
            )
            to_sql(left_child, depth + 1)

            # Right branch: feature > threshold.
            sql_lines.append(
                "{}ELSE -- if {} > {}".format(indent, name, threshold)
            )
            to_sql(right_child, depth + 1)
        else:
            # argmax returns the first index of the maximum, which matches
            # the previous mask-based lookup for 3 classes but also works
            # for any other number of classes (the old hard-coded
            # ``[0, 1, 2]`` array raised IndexError otherwise).
            class_index = np.argmax(tree_.value[node][0])
            sql_lines.append("{}'Class_{}'".format(indent, class_index))

    to_sql(0, 1)
    return sql_lines
def find_END(sql_list):
    """Return the indices of the lines that must be followed by ``END``.

    A line qualifies when it is a class leaf (``'Class_...'``) that comes
    directly after an ``ELSE`` line, i.e. it is the last line of an ELSE
    branch and therefore completes at least one CASE expression.

    Lines are classified by their leading token after the indent rather
    than by substring search, so a feature name that happens to contain
    ``ELSE`` or ``Class_`` can no longer be mistaken for a keyword or a
    leaf.

    Args:
        sql_list: Lines of strings in the format produced by
            ``tree_to_code``.

    Returns:
        list[int]: Indices into ``sql_list``, in ascending order.
    """
    need_end = []
    # Pair every line with its successor; idx is the successor's position.
    for idx, (previous, current) in enumerate(zip(sql_list, sql_list[1:]), start=1):
        follows_else = previous.lstrip().startswith('ELSE')
        is_leaf = current.lstrip().startswith("'Class_")
        if follows_else and is_leaf:
            need_end.append(idx)
    return need_end
def add_END(sql_list, need_end):
    """Append the ``END`` keywords that close every CASE expression.

    The previous implementation wrote exactly one ``END`` on each marked
    line and pushed all remaining ones to the very end of the statement.
    That is wrong whenever a leaf in the middle of the list completes more
    than one nested CASE (for example a split in a left branch whose own
    right child is another split): the outer CASE was then closed only
    after its parent's ``ELSE``, which is invalid SQL.

    This version replays the lines with a stack of open CASE expressions
    and counts how many of them each leaf completes, so every ``END`` is
    placed directly after the branch it closes. The layout is unchanged in
    all cases the old code handled correctly: the final leaf still carries
    a single inline ``END`` and any remaining ones go on a trailing line.

    Args:
        sql_list: Lines in the format produced by ``tree_to_code``.
            The list is modified in place.
        need_end: Indices of the lines that receive inline ``END``
            keywords, as returned by ``find_END``.

    Returns:
        list[str]: The same ``sql_list`` object, with ``END`` keywords
        appended to the marked lines and one extra trailing line holding
        the remaining ``END`` keywords (an empty string if none remain).
    """
    marked = set(need_end)
    last_idx = len(sql_list) - 1
    # One flag per open CASE: True once its ELSE line has been seen.
    open_cases = []
    case_total = 0
    inline_total = 0

    for idx, line in enumerate(sql_list):
        token = line.lstrip()
        completed = 0
        if token.startswith('CASE'):
            case_total += 1
            open_cases.append(False)
        elif token.startswith('ELSE'):
            if open_cases:
                open_cases[-1] = True
        else:
            # A leaf finishes the current branch. Every enclosing CASE that
            # is already in its ELSE branch is complete as well, and
            # completing it may in turn finish its parent's ELSE branch.
            while open_cases and open_cases[-1]:
                open_cases.pop()
                completed += 1

        if idx in marked:
            # Keep a single inline END on the final line so the remaining
            # ones land on the trailing line, as before.
            count = 1 if idx == last_idx else max(1, completed)
            sql_list[idx] = line + ' END' * count
            inline_total += count

    # Whatever has not been closed inline is closed at the very end.
    sql_list.append(' END ' * (case_total - inline_total))
    return sql_list
2、用 iris 数据集测试
from sklearn.datasets import load_iris
# Load the iris data set and fit a classifier limited to three levels.
# ``fit`` returns the estimator itself, so construction and training are
# chained into a single statement.
iris = load_iris()
dtree = DecisionTreeClassifier(max_depth=3).fit(iris.data, iris.target)
# Bare expression: the value is discarded when this runs as a script
# (presumably left over from interactive inspection).
dtree.feature_importances_
if __name__ == "__main__":
    # feature_names[i] is the SQL column name used for feature index i.
    feature_names = ['sepal_length', 'sepal_width', 'petal_length', 'petal_width']

    # Build the CASE/ELSE lines, then close every CASE with END.
    Case_list = tree_to_code(dtree, feature_names)
    need_end = find_END(Case_list)
    Case_list = add_END(Case_list, need_end)

    # str.join builds the same newline-separated text as the former
    # six.moves.reduce fold, without the extra third-party dependency.
    sql_final = "\n".join(str(line) for line in Case_list)
    print(sql_final)
3、用 iris 数据集测试——SQL & 数据库结果
CASE WHEN petal_width <= 0.800000011920929 THEN
'Class_0'
ELSE -- if petal_width > 0.800000011920929
CASE WHEN petal_width <= 1.75 THEN
CASE WHEN petal_length <= 4.950000047683716 THEN
'Class_1'
ELSE -- if petal_length > 4.950000047683716
'Class_2' END
ELSE -- if petal_width > 1.75
CASE WHEN petal_length <= 4.8500001430511475 THEN
'Class_2'
ELSE -- if petal_length > 4.8500001430511475
'Class_2' END
END END
数据库结果