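Python decision tree rule extraction: how to extract decision rules that define the final/terminal nodes of a decision tree classifier and print code that uses numpy arrays (Stack Overflow)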

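I am trying to extract the decision rules that lead to each terminal node, and to print code that predicts the terminal node number using pandas/numpy arrays. I found a solution for extracting the rules in (How to extract the decision rules from scikit-learn decision-tree?), but I am not sure how to extend it to produce what I need. The linked question has many answers. Below is the one I am referring to.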

import pandas as pd
import numpy as np
from sklearn.tree import DecisionTreeClassifier

# dummy data:
df = pd.DataFrame({'col1':[0,1,2,3],'col2':[3,4,5,6],'dv':[0,1,0,1]})
df

# create decision tree
dt = DecisionTreeClassifier(random_state=0, max_depth=5, min_samples_leaf=1)
# select the feature columns with a list (a tuple can be read as a single MultiIndex key)
dt.fit(df.loc[:,['col1','col2']], df.dv)

# This function first starts with the leaf nodes (identified by -1 in the child arrays) and then recursively finds the parents.
# I call this a node's 'lineage'. Along the way, I grab the values I need to create if/then/else SAS logic:
def get_lineage(tree, feature_names):
    left = tree.tree_.children_left
    right = tree.tree_.children_right
    threshold = tree.tree_.threshold
    features = [feature_names[i] for i in tree.tree_.feature]

    # get ids of leaf nodes
    idx = np.argwhere(left == -1)[:,0]

    def recurse(left, right, child, lineage=None):
        if lineage is None:
            lineage = [child]
        if child in left:
            parent = np.where(left == child)[0].item()
            split = 'l'
        else:
            parent = np.where(right == child)[0].item()
            split = 'r'

        lineage.append((parent, split, threshold[parent], features[parent]))

        if parent == 0:
            lineage.reverse()
            return lineage
        else:
            return recurse(left, right, parent, lineage)

    for child in idx:
        for node in recurse(left, right, child):
            print(node)

get_lineage(dt, df.columns)

When you run the code, it prints the following:

(0, 'l', 3.5, 'col2')
1
(0, 'r', 3.5, 'col2')
(2, 'l', 1.5, 'col1')
3
(0, 'r', 3.5, 'col2')
(2, 'r', 1.5, 'col1')
(4, 'l', 2.5, 'col1')
5
(0, 'r', 3.5, 'col2')
(2, 'r', 1.5, 'col1')
(4, 'r', 2.5, 'col1')
6
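The bare integers in this output (1, 3, 5, 6) are the leaf node ids. As a cross-check on any generated rules, scikit-learn's `apply` method returns the leaf index for each row directly. The following is a minimal sketch that assumes the same dummy data and settings as above; the exact ids may vary with the scikit-learn version, so none are asserted here.

```python
import numpy as np
import pandas as pd
from sklearn.tree import DecisionTreeClassifier

df = pd.DataFrame({'col1': [0, 1, 2, 3], 'col2': [3, 4, 5, 6], 'dv': [0, 1, 0, 1]})
X = df[['col1', 'col2']]

dt = DecisionTreeClassifier(random_state=0, max_depth=5, min_samples_leaf=1)
dt.fit(X, df['dv'])

# apply() returns, for each row, the index of the leaf that row ends up in.
leaf_ids = dt.apply(X)
df['Terminal_Node_Num'] = leaf_ids

# Leaves are the nodes whose entry in children_left is -1.
all_leaves = bool(np.all(dt.tree_.children_left[leaf_ids] == -1))
print(df)
print("all returned ids are leaves:", all_leaves)
```

This does not produce the printable rules themselves, but it gives a reference column to compare against once the rules are generated.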

How can I extend it to print something like the following:

df['Terminal_Node_Num']=np.where(df.loc[:,'col2']<=3.5,1,0)
df['Terminal_Node_Num']=np.where(((df.loc[:,'col2']>3.5) & (df.loc[:,'col1']<=1.5)), 3, df['Terminal_Node_Num'])
df['Terminal_Node_Num']=np.where(((df.loc[:,'col2']>3.5) & (df.loc[:,'col1']>1.5) & (df.loc[:,'col1']<=2.5)), 5, df['Terminal_Node_Num'])
df['Terminal_Node_Num']=np.where(((df.loc[:,'col2']>3.5) & (df.loc[:,'col1']>1.5) & (df.loc[:,'col1']>2.5)), 6, df['Terminal_Node_Num'])
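One possible way to generate statements of this shape is to walk the tree from the root down, collecting a `<=` condition on each left branch and a `>` condition on each right branch, and then format one `np.where` line per leaf. The sketch below is only an illustration under stated assumptions: the helper names `leaf_rules` and `np_where_lines` are hypothetical, and the tree arrays are transcribed by hand from the lineage output above rather than read from a fitted model. It also does not handle a tree that consists of a single root leaf.

```python
def leaf_rules(children_left, children_right, threshold, feature, feature_names):
    """Return {leaf_id: [(feature_name, op, threshold), ...]}, walking down from the root."""
    rules = {}

    def walk(node, conditions):
        if children_left[node] == -1:  # a leaf has -1 in the child arrays
            rules[node] = conditions
            return
        name = feature_names[feature[node]]
        thr = threshold[node]
        walk(children_left[node], conditions + [(name, "<=", thr)])   # left branch
        walk(children_right[node], conditions + [(name, ">", thr)])   # right branch

    walk(0, [])
    return rules


def np_where_lines(rules, df_name="df", target="Terminal_Node_Num"):
    """Format each leaf's conditions as one np.where assignment string."""
    lines = []
    for i, (leaf, conditions) in enumerate(sorted(rules.items())):
        mask = " & ".join(
            f"({df_name}.loc[:,'{name}']{op}{thr})" for name, op, thr in conditions
        )
        # The first statement falls back to 0; later ones keep the existing value.
        fallback = "0" if i == 0 else f"{df_name}['{target}']"
        lines.append(f"{df_name}['{target}']=np.where(({mask}), {leaf}, {fallback})")
    return lines


# Assumption: arrays transcribed from the lineage output above (-1 / -2 mark leaves).
children_left = [1, -1, 3, -1, 5, -1, -1]
children_right = [2, -1, 4, -1, 6, -1, -1]
threshold = [3.5, -2.0, 1.5, -2.0, 2.5, -2.0, -2.0]
feature = [1, -2, 0, -2, 0, -2, -2]

rules = leaf_rules(children_left, children_right, threshold, feature, ["col1", "col2"])
for line in np_where_lines(rules):
    print(line)
```

With a fitted classifier, the same arrays should be available as `dt.tree_.children_left`, `dt.tree_.children_right`, `dt.tree_.threshold` and `dt.tree_.feature`. Passing only the feature column names (not `df.columns`, which also contains `dv`) keeps the feature indices aligned with the columns used in `fit`.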
