背景
通过打印随机森林中每棵决策树对某个测试样本的决策路径（依次经过的分裂条件以及最终到达的叶节点），说明模型为何对该样本给出这样的预测，从而实现树模型对单个预测样本的可解释性。
实现
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
# Build a small synthetic binary-classification problem and fit a random forest.
X, y = make_classification(n_samples=1000,
                           n_features=6,
                           n_informative=3,
                           n_classes=2,
                           random_state=0,
                           shuffle=False)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
estimator = RandomForestClassifier(n_estimators=10,
                                   random_state=0)
estimator.fit(X_train, y_train)

# The forest itself has no `tree_` attribute: each fitted tree in
# `estimator.estimators_` exposes one. A `tree_` stores the binary tree as a
# set of parallel arrays, where the i-th element of each array describes node
# `i` and node 0 is the root. Some arrays only apply to split nodes (or only
# to leaves); for nodes of the other type their values are arbitrary.
#
# The arrays collected below, one entry per tree, are:
# - node_count: number of nodes in the tree
# - children_left / children_right: ids of the left / right child of a node
#   (a node whose two child ids are equal is a leaf)
# - feature: index of the feature used to split the node
# - threshold: split threshold at the node
n_nodes_ = [t.tree_.node_count for t in estimator.estimators_]
children_left_ = [t.tree_.children_left for t in estimator.estimators_]
children_right_ = [t.tree_.children_right for t in estimator.estimators_]
feature_ = [t.tree_.feature for t in estimator.estimators_]
threshold_ = [t.tree_.threshold for t in estimator.estimators_]
def explore_tree(estimator, n_nodes, children_left, children_right, feature, threshold,
                 suffix='', print_tree=False, sample_id=0, feature_names=None,
                 X_data=None):
    """Print one tree's size (optionally its full structure) and the rules it applies to a sample.

    Parameters
    ----------
    estimator : fitted tree estimator
        Object providing ``decision_path``, ``apply`` and ``predict``
        (e.g. one element of a random forest's ``estimators_``).
    n_nodes : int
        Number of nodes in the tree (``tree_.node_count``).
    children_left, children_right : array-like of int
        Per-node child ids; a node whose two child ids are equal is a leaf.
    feature : array-like of int
        Per-node index of the split feature (arbitrary for leaves).
    threshold : array-like of float
        Per-node split threshold (arbitrary for leaves).
    suffix : optional
        Unused; kept only for backward compatibility.
    print_tree : bool, optional
        If True, also print every node of the tree, tab-indented by depth.
    sample_id : int, optional
        Row of the sample matrix whose decision path is printed.
    feature_names : sequence, optional
        One name per feature column. Defaults to the column indices.
    X_data : 2-D array-like, optional
        Samples to explain. Defaults to the module-level ``X_test``.

    Raises
    ------
    ValueError
        If ``feature_names`` does not have exactly one entry per feature column.
    """
    # Only touch the global X_test when the caller did not pass samples.
    samples = np.asarray(X_test if X_data is None else X_data)
    n_features = samples.shape[1]
    if feature_names is None:
        # Fall back to column indices (the old code wrongly used the per-node
        # `feature` array here, whose length is n_nodes, not n_features).
        feature_names = [str(j) for j in range(n_features)]
    if len(feature_names) != n_features:
        raise ValueError("The feature names do not match the number of features.")

    node_depth, is_leaves = _node_depths(n_nodes, children_left, children_right)
    print("The binary tree structure has %s nodes" % n_nodes)
    if print_tree:
        _print_tree_structure(n_nodes, node_depth, is_leaves, children_left,
                              children_right, feature, threshold)
    print()
    _print_sample_rules(estimator, samples, sample_id, feature, threshold,
                        feature_names)


def _node_depths(n_nodes, children_left, children_right):
    """Return (depth of every node, boolean leaf mask) via a depth-first walk from the root."""
    node_depth = np.zeros(shape=n_nodes, dtype=np.int64)
    is_leaves = np.zeros(shape=n_nodes, dtype=bool)
    stack = [(0, 0)]  # (node id, depth); the root has depth 0
    while stack:
        node_id, depth = stack.pop()
        node_depth[node_id] = depth
        left, right = children_left[node_id], children_right[node_id]
        if left != right:
            # Split node: both children sit one level deeper.
            stack.append((left, depth + 1))
            stack.append((right, depth + 1))
        else:
            is_leaves[node_id] = True
    return node_depth, is_leaves


def _print_tree_structure(n_nodes, node_depth, is_leaves, children_left,
                          children_right, feature, threshold):
    """Print every node of the tree, tab-indented by its depth."""
    print("Tree structure: \n")
    for i in range(n_nodes):
        indent = "\t" * int(node_depth[i])
        if is_leaves[i]:
            print("%snode=%s leaf node." % (indent, i))
        else:
            print("%snode=%s test node: go to node %s if X[:, %s] <= %s else to "
                  "node %s."
                  % (indent, i, children_left[i], feature[i], threshold[i],
                     children_right[i]))
    print("\n")


def _print_sample_rules(estimator, samples, sample_id, feature, threshold,
                        feature_names):
    """Print the split tests one sample passes through, its leaf, and its prediction."""
    x = samples[sample_id]
    # Query the estimator with just this row instead of the whole matrix.
    row = x.reshape(1, -1)
    # decision_path returns a sparse node-indicator matrix; for a single row the
    # non-zero column indices are the ids of the nodes the sample visits.
    node_indicator = estimator.decision_path(row)
    node_index = node_indicator.indices[node_indicator.indptr[0]:node_indicator.indptr[1]]
    leaf_id = estimator.apply(row)[0]

    print(x)
    print('Rules used to predict sample %s: ' % sample_id)
    for node_id in node_index:
        if node_id == leaf_id:
            # A leaf has no split test; its feature/threshold entries are
            # placeholders, so report the leaf instead of printing a bogus rule.
            print("==> Predicted leaf index %s\n" % node_id)
            continue
        value = x[feature[node_id]]
        threshold_sign = "<=" if value <= threshold[node_id] else ">"
        print("decision id node %s : (X_test[%s, '%s'] (= %s) %s %s)"
              % (node_id,
                 sample_id,
                 feature_names[feature[node_id]],
                 value,
                 threshold_sign,
                 threshold[node_id]))
    print("Prediction for sample %d: %s" % (sample_id, estimator.predict(row)[0]))
# Explain sample 1 with every tree of the forest, one tree at a time.
# The feature names do not depend on the tree, so build them once.
feature_names_ = ["Feature_%d" % j for j in range(X.shape[1])]
for i, tree in enumerate(estimator.estimators_):
    print("Tree %d\n" % i)
    explore_tree(tree, n_nodes_[i], children_left_[i],
                 children_right_[i], feature_[i], threshold_[i],
                 suffix=i, sample_id=1, feature_names=feature_names_)
    print('\n' * 2)
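结果（sample_id=1 时的输出。注意：每棵树中紧跟在“==> Predicted leaf index”之后的那一行对应的是叶节点本身；叶节点没有分裂条件，该行显示的特征名和阈值 -2.0 只是占位值，并非真实的判断规则。）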
Tree 0
The binary tree structure has 127 nodes
[ 9. 1.32658511 -0.08002818 0.88295736 2.24224824 -0.71469736]
Rules used to predict sample 1:
decision id node 0 : (X_test[1, 'Feature_2'] (= -0.08002817952064323) <= 2.4628634452819824)
decision id node 1 : (X_test[1, 'Feature_4'] (= 2.2422482391211678) > -0.6601005792617798)
decision id node 9 : (X_test[1, 'Feature_3'] (= 0.8829573603562209) <= 0.9460248947143555)
decision id node 10 : (X_test[1, 'Feature_1'] (= 1.3265851104352575) > 1.2726516723632812)
decision id node 108 : (X_test[1, 'Feature_2'] (= -0.08002817952064323) > -1.6057568788528442)
decision id node 112 : (X_test[1, 'Feature_2'] (= -0.08002817952064323) > -1.585638165473938)
decision id node 114 : (X_test[1, 'Feature_4'] (= 2.2422482391211678) > 0.3288392424583435)
decision id node 116 : (X_test[1, 'Feature_4'] (= 2.2422482391211678) > 2.0278358459472656)
==> Predicted leaf index
decision id node 118 : (X_test[1, 'Feature_4'] (= 2.2422482391211678) > -2.0)
Prediction for sample 1: 1.0
Tree 1
The binary tree structure has 135 nodes
[ 9. 1.32658511 -0.08002818 0.88295736 2.24224824 -0.71469736]
Rules used to predict sample 1:
decision id node 0 : (X_test[1, 'Feature_3'] (= 0.8829573603562209) > 0.7038955688476562)
decision id node 94 : (X_test[1, 'Feature_1'] (= 1.3265851104352575) <= 1.4650013446807861)
decision id node 95 : (X_test[1, 'Feature_5'] (= -0.7146973587899221) <= 1.0235941410064697)
==> Predicted leaf index
decision id node 96 : (X_test[1, 'Feature_4'] (= 2.2422482391211678) > -2.0)
Prediction for sample 1: 1.0
Tree 2
The binary tree structure has 187 nodes
[ 9. 1.32658511 -0.08002818 0.88295736 2.24224824 -0.71469736]
Rules used to predict sample 1:
decision id node 0 : (X_test[1, 'Feature_3'] (= 0.8829573603562209) > 0.5484486818313599)
decision id node 124 : (X_test[1, 'Feature_1'] (= 1.3265851104352575) > 1.016119122505188)
decision id node 138 : (X_test[1, 'Feature_1'] (= 1.3265851104352575) <= 2.2443997859954834)
decision id node 139 : (X_test[1, 'Feature_4'] (= 2.2422482391211678) > 2.0519332885742188)
==> Predicted leaf index
decision id node 179 : (X_test[1, 'Feature_4'] (= 2.2422482391211678) > -2.0)
Prediction for sample 1: 1.0
Tree 3
The binary tree structure has 121 nodes
[ 9. 1.32658511 -0.08002818 0.88295736 2.24224824 -0.71469736]
Rules used to predict sample 1:
decision id node 0 : (X_test[1, 'Feature_3'] (= 0.8829573603562209) <= 0.9333831071853638)
decision id node 1 : (X_test[1, 'Feature_4'] (= 2.2422482391211678) > -1.2789976596832275)
decision id node 9 : (X_test[1, 'Feature_3'] (= 0.8829573603562209) > 0.5068763494491577)
decision id node 83 : (X_test[1, 'Feature_4'] (= 2.2422482391211678) > 0.336268812417984)
decision id node 89 : (X_test[1, 'Feature_2'] (= -0.08002817952064323) <= 0.05391387641429901)
decision id node 90 : (X_test[1, 'Feature_2'] (= -0.08002817952064323) > -0.34285539388656616)
decision id node 92 : (X_test[1, 'Feature_4'] (= 2.2422482391211678) > 0.9664593935012817)
decision id node 94 : (X_test[1, 'Feature_5'] (= -0.7146973587899221) <= 0.46754080057144165)
==> Predicted leaf index
decision id node 95 : (X_test[1, 'Feature_4'] (= 2.2422482391211678) > -2.0)
Prediction for sample 1: 0.0
Tree 4
The binary tree structure has 159 nodes
[ 9. 1.32658511 -0.08002818 0.88295736 2.24224824 -0.71469736]
Rules used to predict sample 1:
decision id node 0 : (X_test[1, 'Feature_3'] (= 0.8829573603562209) > 0.673632800579071)
decision id node 130 : (X_test[1, 'Feature_3'] (= 0.8829573603562209) <= 0.9460248947143555)
decision id node 131 : (X_test[1, 'Feature_2'] (= -0.08002817952064323) > -0.413898229598999)
decision id node 143 : (X_test[1, 'Feature_5'] (= -0.7146973587899221) > -1.4878294467926025)
==> Predicted leaf index
decision id node 147 : (X_test[1, 'Feature_4'] (= 2.2422482391211678) > -2.0)
Prediction for sample 1: 1.0
Tree 5
The binary tree structure has 123 nodes
[ 9. 1.32658511 -0.08002818 0.88295736 2.24224824 -0.71469736]
Rules used to predict sample 1:
decision id node 0 : (X_test[1, 'Feature_1'] (= 1.3265851104352575) > -0.3723614811897278)
decision id node 18 : (X_test[1, 'Feature_4'] (= 2.2422482391211678) > 0.08041483908891678)
decision id node 40 : (X_test[1, 'Feature_4'] (= 2.2422482391211678) > 1.4040119647979736)
decision id node 86 : (X_test[1, 'Feature_3'] (= 0.8829573603562209) > 0.6183018684387207)
decision id node 112 : (X_test[1, 'Feature_2'] (= -0.08002817952064323) <= 0.02021688222885132)
decision id node 113 : (X_test[1, 'Feature_5'] (= -0.7146973587899221) <= -0.7041467428207397)
==> Predicted leaf index
decision id node 114 : (X_test[1, 'Feature_4'] (= 2.2422482391211678) > -2.0)
Prediction for sample 1: 0.0
Tree 6
The binary tree structure has 119 nodes
[ 9. 1.32658511 -0.08002818 0.88295736 2.24224824 -0.71469736]
Rules used to predict sample 1:
decision id node 0 : (X_test[1, 'Feature_2'] (= -0.08002817952064323) <= 2.1899335384368896)
decision id node 1 : (X_test[1, 'Feature_1'] (= 1.3265851104352575) > -0.6121580600738525)
decision id node 13 : (X_test[1, 'Feature_5'] (= -0.7146973587899221) <= 2.0935070514678955)
decision id node 14 : (X_test[1, 'Feature_2'] (= -0.08002817952064323) <= 0.21425755321979523)
decision id node 15 : (X_test[1, 'Feature_3'] (= 0.8829573603562209) > -0.20205456018447876)
decision id node 19 : (X_test[1, 'Feature_3'] (= 0.8829573603562209) <= 0.9333831071853638)
decision id node 20 : (X_test[1, 'Feature_5'] (= -0.7146973587899221) <= 0.8492134809494019)
decision id node 21 : (X_test[1, 'Feature_5'] (= -0.7146973587899221) <= -0.5752145648002625)
decision id node 22 : (X_test[1, 'Feature_4'] (= 2.2422482391211678) > 0.039557620882987976)
decision id node 28 : (X_test[1, 'Feature_2'] (= -0.08002817952064323) > -0.5093963146209717)
decision id node 30 : (X_test[1, 'Feature_4'] (= 2.2422482391211678) > 0.8792927265167236)
decision id node 32 : (X_test[1, 'Feature_2'] (= -0.08002817952064323) <= -0.014746115542948246)
decision id node 33 : (X_test[1, 'Feature_1'] (= 1.3265851104352575) > 0.7183631658554077)
==> Predicted leaf index
decision id node 35 : (X_test[1, 'Feature_4'] (= 2.2422482391211678) > -2.0)
Prediction for sample 1: 0.0
Tree 7
The binary tree structure has 129 nodes
[ 9. 1.32658511 -0.08002818 0.88295736 2.24224824 -0.71469736]
Rules used to predict sample 1:
decision id node 0 : (X_test[1, 'Feature_3'] (= 0.8829573603562209) <= 0.9460248947143555)
decision id node 1 : (X_test[1, 'Feature_3'] (= 0.8829573603562209) > 0.031008129939436913)
decision id node 35 : (X_test[1, 'Feature_2'] (= -0.08002817952064323) <= 0.22173278033733368)
decision id node 36 : (X_test[1, 'Feature_3'] (= 0.8829573603562209) > 0.2683340311050415)
decision id node 54 : (X_test[1, 'Feature_0'] (= 9.0) > 1.5)
decision id node 66 : (X_test[1, 'Feature_1'] (= 1.3265851104352575) > 0.9638141393661499)
decision id node 70 : (X_test[1, 'Feature_3'] (= 0.8829573603562209) > 0.41362708806991577)
decision id node 72 : (X_test[1, 'Feature_2'] (= -0.08002817952064323) > -1.3964011669158936)
decision id node 82 : (X_test[1, 'Feature_4'] (= 2.2422482391211678) > 0.4883882403373718)
decision id node 84 : (X_test[1, 'Feature_1'] (= 1.3265851104352575) <= 1.3384532928466797)
==> Predicted leaf index
decision id node 85 : (X_test[1, 'Feature_4'] (= 2.2422482391211678) > -2.0)
Prediction for sample 1: 0.0
Tree 8
The binary tree structure has 135 nodes
[ 9. 1.32658511 -0.08002818 0.88295736 2.24224824 -0.71469736]
Rules used to predict sample 1:
decision id node 0 : (X_test[1, 'Feature_1'] (= 1.3265851104352575) > -0.3356468677520752)
decision id node 24 : (X_test[1, 'Feature_1'] (= 1.3265851104352575) <= 1.3858520984649658)
decision id node 25 : (X_test[1, 'Feature_4'] (= 2.2422482391211678) > 0.08041483908891678)
decision id node 37 : (X_test[1, 'Feature_2'] (= -0.08002817952064323) <= 0.2188812792301178)
decision id node 38 : (X_test[1, 'Feature_3'] (= 0.8829573603562209) > 0.7280778884887695)
decision id node 68 : (X_test[1, 'Feature_5'] (= -0.7146973587899221) > -1.3525903224945068)
decision id node 70 : (X_test[1, 'Feature_0'] (= 9.0) > 7.0)
==> Predicted leaf index
decision id node 72 : (X_test[1, 'Feature_4'] (= 2.2422482391211678) > -2.0)
Prediction for sample 1: 0.0
Tree 9
The binary tree structure has 121 nodes
[ 9. 1.32658511 -0.08002818 0.88295736 2.24224824 -0.71469736]
Rules used to predict sample 1:
decision id node 0 : (X_test[1, 'Feature_5'] (= -0.7146973587899221) > -0.9643558859825134)
decision id node 40 : (X_test[1, 'Feature_1'] (= 1.3265851104352575) > -0.3356468677520752)
decision id node 54 : (X_test[1, 'Feature_4'] (= 2.2422482391211678) > -0.23724764585494995)
decision id node 62 : (X_test[1, 'Feature_1'] (= 1.3265851104352575) <= 1.3771547079086304)
decision id node 63 : (X_test[1, 'Feature_2'] (= -0.08002817952064323) <= 0.09199776500463486)
decision id node 64 : (X_test[1, 'Feature_2'] (= -0.08002817952064323) > -0.9304192066192627)
decision id node 74 : (X_test[1, 'Feature_4'] (= 2.2422482391211678) > 0.6660356521606445)
decision id node 78 : (X_test[1, 'Feature_1'] (= 1.3265851104352575) <= 1.3506155014038086)
==> Predicted leaf index
decision id node 79 : (X_test[1, 'Feature_4'] (= 2.2422482391211678) > -2.0)
Prediction for sample 1: 0.0