Extracting the rule path of each data point from a decision tree with sklearn in Python

I'm using a decision tree model, and I want to extract the decision path for each data point in order to understand what caused the Y rather than to predict it.

How can I do that? I couldn't find any documentation.

Solution

Here is an example using the iris dataset.

from sklearn.datasets import load_iris
from sklearn import tree
import graphviz

iris = load_iris()
clf = tree.DecisionTreeClassifier()
clf = clf.fit(iris.data, iris.target)

dot_data = tree.export_graphviz(clf, out_file=None,
                                feature_names=iris.feature_names,
                                class_names=iris.target_names,
                                filled=True, rounded=True,
                                special_characters=True)
graph = graphviz.Source(dot_data)

# this will create an iris.pdf file with the rule paths
graph.render("iris")

[Figure: the rendered decision tree for the iris dataset]
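If you only want the rules as plain text instead of a rendered graph, newer scikit-learn releases (0.21+) also provide sklearn.tree.export_text. A minimal sketch:

from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, export_text

iris = load_iris()
clf = DecisionTreeClassifier().fit(iris.data, iris.target)

# prints the whole tree as nested, human-readable if/else rules
print(export_text(clf, feature_names=iris.feature_names))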

EDIT: the following code is from the scikit-learn documentation, with some small changes to address your goal.

import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

iris = load_iris()
X = iris.data
y = iris.target

X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

estimator = DecisionTreeClassifier(max_leaf_nodes=3, random_state=0)
estimator.fit(X_train, y_train)

# The decision estimator has an attribute called tree_ which stores the entire
# tree structure and allows access to low-level attributes. The binary tree
# tree_ is represented as a number of parallel arrays. The i-th element of each
# array holds information about node `i`. Node 0 is the tree's root. NOTE:
# some of the arrays only apply to either leaves or split nodes; in that
# case the values for nodes of the other type are arbitrary!
#
# Among those arrays, we have:
#   - children_left, id of the left child of the node
#   - children_right, id of the right child of the node
#   - feature, feature used for splitting the node
#   - threshold, threshold value at the node

n_nodes = estimator.tree_.node_count
children_left = estimator.tree_.children_left
children_right = estimator.tree_.children_right
feature = estimator.tree_.feature
threshold = estimator.tree_.threshold
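# For the small tree fitted above (max_leaf_nodes=3) these arrays come out
# roughly as follows (values rounded, read off the structure printed below;
# leaves carry the sentinel -1 in children_* and -2 in feature/threshold):
#
#   children_left  = [ 1, -1,  3, -1, -1]
#   children_right = [ 2, -1,  4, -1, -1]
#   feature        = [ 3, -2,  2, -2, -2]
#   threshold      = [ 0.8, -2.0, 4.95, -2.0, -2.0]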

# The tree structure can be traversed to compute various properties such
# as the depth of each node and whether or not it is a leaf.
node_depth = np.zeros(shape=n_nodes, dtype=np.int64)
is_leaves = np.zeros(shape=n_nodes, dtype=bool)
stack = [(0, -1)]  # seed is the root node id and its parent depth
while len(stack) > 0:
    node_id, parent_depth = stack.pop()
    node_depth[node_id] = parent_depth + 1

    # If we have a test node
    if children_left[node_id] != children_right[node_id]:
        stack.append((children_left[node_id], parent_depth + 1))
        stack.append((children_right[node_id], parent_depth + 1))
    else:
        is_leaves[node_id] = True

print("The binary tree structure has %s nodes and has "

"the following tree structure:"

% n_nodes)

for i in range(n_nodes):

if is_leaves[i]:

print("%snode=%s leaf node." % (node_depth[i] * "\t", i))

else:

print("%snode=%s test node: go to node %s if X[:, %s] <= %s else to "

"node %s."

% (node_depth[i] * "\t",

i,

children_left[i],

feature[i],

threshold[i],

children_right[i],

))

print()

# First let's retrieve the decision path of each sample. The decision_path
# method returns a node indicator matrix: a non-zero element at position
# (i, j) indicates that sample i goes through node j.
node_indicator = estimator.decision_path(X_test)
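# decision_path returns a scipy.sparse CSR matrix of shape
# (n_samples, n_nodes); densifying a few rows makes the encoding visible,
# e.g. a row [1 0 1 0 1] means that sample visits nodes 0, 2 and 4.
print(node_indicator.toarray()[:3])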

# Similarly, we can also get the id of the leaf reached by each sample.
leave_id = estimator.apply(X_test)

# Now, it's possible to get the tests that were used to predict a sample or
# a group of samples. First, let's do it for a single sample.

# HERE IS WHAT YOU WANT
sample_id = 0
node_index = node_indicator.indices[node_indicator.indptr[sample_id]:
                                    node_indicator.indptr[sample_id + 1]]

print('Rules used to predict sample %s: ' % sample_id)
for node_id in node_index:
    if leave_id[sample_id] == node_id:
        print("leaf node {} reached, no decision here".format(leave_id[sample_id]))
    else:  # <-- added else to iterate through the decision nodes
        if X_test[sample_id, feature[node_id]] <= threshold[node_id]:
            threshold_sign = "<="
        else:
            threshold_sign = ">"
        print("decision id node %s : (X[%s, %s] (= %s) %s %s)"
              % (node_id,
                 sample_id,
                 feature[node_id],
                 X_test[sample_id, feature[node_id]],
                 threshold_sign,
                 threshold[node_id]))

This will print the following at the end:

Rules used to predict sample 0:
decision id node 0 : (X[0, 3] (= 2.4) > 0.800000011920929)
decision id node 2 : (X[0, 2] (= 5.1) > 4.949999809265137)
leaf node 4 reached, no decision here
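Since the goal is the decision path of every data point, you can wrap the same lookup in a loop over all samples. Below is a minimal sketch that continues from the variables defined above; the helper name rules_for_sample is mine, not part of scikit-learn:

def rules_for_sample(sample_id):
    # hypothetical helper: collect the rule path of one sample as strings
    node_index = node_indicator.indices[node_indicator.indptr[sample_id]:
                                        node_indicator.indptr[sample_id + 1]]
    rules = []
    for node_id in node_index:
        if leave_id[sample_id] == node_id:  # stop at the leaf
            continue
        sign = "<=" if X_test[sample_id, feature[node_id]] <= threshold[node_id] else ">"
        rules.append("X[%d] (= %s) %s %s"
                     % (feature[node_id],
                        X_test[sample_id, feature[node_id]],
                        sign,
                        threshold[node_id]))
    return rules

# one rule path per test sample
for sample_id in range(len(X_test)):
    print(sample_id, " AND ".join(rules_for_sample(sample_id)))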
