机器学习技巧_决策树规则输出成SQL

将sklearn的决策树规则输出成SQL

主要使用sklearn.tree._tree 读取决策树的信息

1、输出成SQL的主要函数

import numpy as np
from sklearn.tree import DecisionTreeClassifier
from sklearn.tree import _tree


def tree_to_code(tree, feature_names):
    tree_ = tree.tree_  # 得到tree
    feature_name = [ ## 找出叶节点结果
        feature_names[i] if i != _tree.TREE_UNDEFINED else "undefined!"  ## 如果是叶  (没有判断条件 TREE_UNDEFINED)
        for i in tree_.feature  # array([ 3, -2,  3,  2, -2, -2,  2, -2, -2], dtype=int64)
    ]
    # 递归打印
    Case_list = []
    def to_sql(node, depth):
        """
        递归打印
        tree_ 外函数参数
        Case_list 外函数参数
        feature_name 外函数参数
        """
        tp = np.array([0, 1, 2])
        indent = "  " * depth
        if tree_.feature[node] != _tree.TREE_UNDEFINED: ## 如果是节点   (有判断条件) 
            name = feature_name[node]
            threshold = tree_.threshold[node] ## 节点阈值
            ## 左子节点 
            message_left = "{}CASE WHEN {} <= {} THEN".format(indent, name, threshold)
            Case_list.append(message_left)
            to_sql(tree_.children_left[node], depth + 1 )
            ## 右子节点 
            message_right = "{}ELSE  -- if {} > {}".format(indent, name, threshold)
            Case_list.append(message_right)
            to_sql(tree_.children_right[node], depth + 1 )

        else:  ## 如果是叶  (没有判断条件 TREE_UNDEFINED) 
        	# 直接读取叶的种类
            tp_i = tp[tree_.value[node][0] == tree_.value[node][0].max()][0]
            message_left = "{}'Class_{}'".format(indent, tp_i)
            Case_list.append(message_left)

    to_sql(0, 1)

    return  Case_list
 
def find_END(sql_list):
    """
    找到需要补充END 的 行
    """
    need_end = [] # 记录需要加END的index
    n = len(sql_list)
    for idx in range(n):
        if (idx + 1 < n): 
            if ('ELSE' in sql_list[idx]) and ('Class_' in sql_list[idx + 1]):
                need_end.append(idx + 1)
    return need_end

def add_END(sql_list, need_end):
    """
    填上 END 
    """
    n = len(sql_list)
    for i in range(n):
        if i in need_end:
            sql_list[i] = sql_list[i] + ' END'
    
    ## 填补末尾的end 
    n_c = 0 
    for i in sql_list:
        if 'CASE' in i:
            n_c += 1
    sql_list.append( ' END ' * (n_c - len(need_end) ) )
    return sql_list

2、用iris数据集测试

from sklearn.datasets import load_iris 
iris = load_iris()
dtree = DecisionTreeClassifier(max_depth=3)
dtree.fit(iris.data, iris.target)
dtree.feature_importances_


if __name__ == "__main__":    
    feature_names = ['sepal_length', 'sepal_width', 'petal_length', 'petal_width']

    Case_list = tree_to_code(dtree, feature_names)
    need_end = find_END(Case_list)
    Case_list = add_END(Case_list, need_end)

    from six.moves  import reduce
    sql_final = reduce(lambda a, b: str(a) + "\n" + str(b) , Case_list)
    print(sql_final)


3、用iris数据集测试——SQL&数据库结果

CASE WHEN petal_width <= 0.800000011920929 THEN
    'Class_0'
  ELSE  -- if petal_width > 0.800000011920929
    CASE WHEN petal_width <= 1.75 THEN
      CASE WHEN petal_length <= 4.950000047683716 THEN
        'Class_1'
      ELSE  -- if petal_length > 4.950000047683716
        'Class_2' END
    ELSE  -- if petal_width > 1.75
      CASE WHEN petal_length <= 4.8500001430511475 THEN
        'Class_2'
      ELSE  -- if petal_length > 4.8500001430511475
        'Class_2' END
 END  END

数据库结果

在这里插入图片描述

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Scc_hy

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值