Printing a Random Forest's Decision Paths for a Predicted Sample

Background

By printing the decision path a tree model follows for a test sample, we make the model's prediction for that sample interpretable.
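
As a quick point of reference, recent versions of scikit-learn also ship sklearn.tree.export_text, which renders the full rule set of a single fitted tree as plain text. It prints every branch rather than one sample's path, but it is a handy sanity check before tracing paths by hand. A minimal sketch, assuming the `estimator` fitted in the implementation below:

from sklearn.tree import export_text

# Render the rules of the first tree in the fitted forest as text.
# The feature names mirror the ones used at the end of the script.
print(export_text(estimator.estimators_[0],
                  feature_names=["Feature_%d" % j for j in range(6)]))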

Implementation

import numpy as np

from sklearn.model_selection import train_test_split
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier

X, y = make_classification(n_samples=1000,
                           n_features=6,
                           n_informative=3,
                           n_classes=2,
                           random_state=0,
                           shuffle=False)

X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

estimator = RandomForestClassifier(n_estimators=10,
                                   random_state=0)
estimator.fit(X_train, y_train)

# Each decision tree has an attribute called tree_ which stores the entire
# tree structure and allows access to low-level attributes. The binary tree
# tree_ is represented as a number of parallel arrays. The i-th element of
# each array holds information about node `i`. Node 0 is the tree's root.
# NOTE: some of the arrays only apply to leaves or to split nodes; in that
# case the values for nodes of the other type are arbitrary!
#
# Among those arrays, we have:
#   - left_child, id of the left child of the node
#   - right_child, id of the right child of the node
#   - feature, feature used for splitting the node
#   - threshold, threshold value at the node
#

# Using those arrays, we can parse the tree structure:

n_nodes_ = [t.tree_.node_count for t in estimator.estimators_]
children_left_ = [t.tree_.children_left for t in estimator.estimators_]
children_right_ = [t.tree_.children_right for t in estimator.estimators_]
feature_ = [t.tree_.feature for t in estimator.estimators_]
threshold_ = [t.tree_.threshold for t in estimator.estimators_]

def explore_tree(estimator, n_nodes, children_left, children_right, feature,
                 threshold, suffix='', print_tree=False, sample_id=0,
                 feature_names=None):

    # `feature` holds per-node feature indices, not names, so fall back to
    # generic names when none are given.
    if feature_names is None:
        feature_names = ["X[%d]" % j for j in range(X.shape[1])]

    assert len(feature_names) == X.shape[1], \
        "The feature names do not match the number of features."

    # The tree structure can be traversed to compute various properties, such
    # as the depth of each node and whether or not it is a leaf.
    node_depth = np.zeros(shape=n_nodes, dtype=np.int64)
    is_leaves = np.zeros(shape=n_nodes, dtype=bool)

    stack = [(0, -1)]  # seed is the root node id and its parent depth
    while len(stack) > 0:
        node_id, parent_depth = stack.pop()
        node_depth[node_id] = parent_depth + 1

        # If we have a test node
        if (children_left[node_id] != children_right[node_id]):
            stack.append((children_left[node_id], parent_depth + 1))
            stack.append((children_right[node_id], parent_depth + 1))
        else:
            is_leaves[node_id] = True

    print("The binary tree structure has %s nodes"
          % n_nodes)
    if print_tree:
        print("Tree structure: \n")
        for i in range(n_nodes):
            if is_leaves[i]:
                print("%snode=%s leaf node." % (node_depth[i] * "\t", i))
            else:
                print("%snode=%s test node: go to node %s if X[:, %s] <= %s else to "
                      "node %s."
                      % (node_depth[i] * "\t",
                         i,
                         children_left[i],
                         feature[i],
                         threshold[i],
                         children_right[i],
                         ))
            print("\n")
        print()

    # First, let's retrieve the decision path of each sample. The
    # decision_path method returns the node indicator matrix: a non-zero
    # element at position (i, j) indicates that sample i goes through node j.

    node_indicator = estimator.decision_path(X_test)

    # Similarly, we can also have the leaves ids reached by each sample.

    leave_id = estimator.apply(X_test)

    # Now, it's possible to get the tests that were used to predict a sample
    # or a group of samples. First, let's do it for a single sample.

    node_index = node_indicator.indices[node_indicator.indptr[sample_id]:
                                        node_indicator.indptr[sample_id + 1]]

    print(X_test[sample_id,:])

    print('Rules used to predict sample %s: ' % sample_id)
    for node_id in node_index:
        # tabulation = " " * node_depth[node_id]  # indent by tree depth
        tabulation = ""
        if leave_id[sample_id] == node_id:
            print("%s==> Predicted leaf index \n" % (tabulation))
            # Without a `continue` here, the leaf node itself is printed
            # below as well; its feature/threshold values are arbitrary
            # (hence the trailing "> -2.0" lines in the results).
            #continue

        if (X_test[sample_id, feature[node_id]] <= threshold[node_id]):
            threshold_sign = "<="
        else:
            threshold_sign = ">"

        print("%sdecision id node %s : (X_test[%s, '%s'] (= %s) %s %s)"
              % (tabulation,
                 node_id,
                 sample_id,
                 feature_names[feature[node_id]],
                 X_test[sample_id, feature[node_id]],
                 threshold_sign,
                 threshold[node_id]))
    print("%sPrediction for sample %d: %s"%(tabulation,
                                          sample_id,
                                          estimator.predict(X_test)[sample_id]))

    # For a group of samples, we can also find the nodes they have in common:
#     sample_ids = [sample_id, 1]
#     common_nodes = (node_indicator.toarray()[sample_ids].sum(axis=0) ==
#                     len(sample_ids))

#     common_node_id = np.arange(n_nodes)[common_nodes]

#     print("\nThe following samples %s share the node %s in the tree"
#           % (sample_ids, common_node_id))
#     print("It is %s %% of all nodes." % (100 * len(common_node_id) / n_nodes,))

#     for sample_id_ in sample_ids:
#         print("Prediction for sample %d: %s"%(sample_id_,
#                                           estimator.predict(X_test)[sample_id_]))

feature_names = ["Feature_%d" % j for j in range(X.shape[1])]
for i, tree in enumerate(estimator.estimators_):
    print("Tree %d\n" % i)
    explore_tree(tree, n_nodes_[i], children_left_[i],
                 children_right_[i], feature_[i], threshold_[i],
                 suffix=str(i), sample_id=1, feature_names=feature_names)
    print('\n' * 2)
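
For reference, the forest object exposes this directly as well: RandomForestClassifier.decision_path returns a single indicator matrix with the columns of all trees stacked side by side, plus an n_nodes_ptr array delimiting each tree's column block. A minimal sketch, reusing the estimator fitted above:

# Forest-level equivalent: one stacked indicator matrix for all trees.
indicator, n_nodes_ptr = estimator.decision_path(X_test)
sample_id = 1
for i in range(len(estimator.estimators_)):
    start, end = n_nodes_ptr[i], n_nodes_ptr[i + 1]
    # Column indices of the slice are node ids local to tree i.
    path = indicator[sample_id, start:end].indices
    print("Tree %d visits nodes %s" % (i, path.tolist()))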

Results

Tree 0

The binary tree structure has 127 nodes
[ 9.          1.32658511 -0.08002818  0.88295736  2.24224824 -0.71469736]
Rules used to predict sample 1: 
decision id node 0 : (X_test[1, 'Feature_2'] (= -0.08002817952064323) <= 2.4628634452819824)
decision id node 1 : (X_test[1, 'Feature_4'] (= 2.2422482391211678) > -0.6601005792617798)
decision id node 9 : (X_test[1, 'Feature_3'] (= 0.8829573603562209) <= 0.9460248947143555)
decision id node 10 : (X_test[1, 'Feature_1'] (= 1.3265851104352575) > 1.2726516723632812)
decision id node 108 : (X_test[1, 'Feature_2'] (= -0.08002817952064323) > -1.6057568788528442)
decision id node 112 : (X_test[1, 'Feature_2'] (= -0.08002817952064323) > -1.585638165473938)
decision id node 114 : (X_test[1, 'Feature_4'] (= 2.2422482391211678) > 0.3288392424583435)
decision id node 116 : (X_test[1, 'Feature_4'] (= 2.2422482391211678) > 2.0278358459472656)
==> Predicted leaf index 

decision id node 118 : (X_test[1, 'Feature_4'] (= 2.2422482391211678) > -2.0)
Prediction for sample 1: 1.0



Tree 1

The binary tree structure has 135 nodes
[ 9.          1.32658511 -0.08002818  0.88295736  2.24224824 -0.71469736]
Rules used to predict sample 1: 
decision id node 0 : (X_test[1, 'Feature_3'] (= 0.8829573603562209) > 0.7038955688476562)
decision id node 94 : (X_test[1, 'Feature_1'] (= 1.3265851104352575) <= 1.4650013446807861)
decision id node 95 : (X_test[1, 'Feature_5'] (= -0.7146973587899221) <= 1.0235941410064697)
==> Predicted leaf index 

decision id node 96 : (X_test[1, 'Feature_4'] (= 2.2422482391211678) > -2.0)
Prediction for sample 1: 1.0



Tree 2

The binary tree structure has 187 nodes
[ 9.          1.32658511 -0.08002818  0.88295736  2.24224824 -0.71469736]
Rules used to predict sample 1: 
decision id node 0 : (X_test[1, 'Feature_3'] (= 0.8829573603562209) > 0.5484486818313599)
decision id node 124 : (X_test[1, 'Feature_1'] (= 1.3265851104352575) > 1.016119122505188)
decision id node 138 : (X_test[1, 'Feature_1'] (= 1.3265851104352575) <= 2.2443997859954834)
decision id node 139 : (X_test[1, 'Feature_4'] (= 2.2422482391211678) > 2.0519332885742188)
==> Predicted leaf index 

decision id node 179 : (X_test[1, 'Feature_4'] (= 2.2422482391211678) > -2.0)
Prediction for sample 1: 1.0



Tree 3

The binary tree structure has 121 nodes
[ 9.          1.32658511 -0.08002818  0.88295736  2.24224824 -0.71469736]
Rules used to predict sample 1: 
decision id node 0 : (X_test[1, 'Feature_3'] (= 0.8829573603562209) <= 0.9333831071853638)
decision id node 1 : (X_test[1, 'Feature_4'] (= 2.2422482391211678) > -1.2789976596832275)
decision id node 9 : (X_test[1, 'Feature_3'] (= 0.8829573603562209) > 0.5068763494491577)
decision id node 83 : (X_test[1, 'Feature_4'] (= 2.2422482391211678) > 0.336268812417984)
decision id node 89 : (X_test[1, 'Feature_2'] (= -0.08002817952064323) <= 0.05391387641429901)
decision id node 90 : (X_test[1, 'Feature_2'] (= -0.08002817952064323) > -0.34285539388656616)
decision id node 92 : (X_test[1, 'Feature_4'] (= 2.2422482391211678) > 0.9664593935012817)
decision id node 94 : (X_test[1, 'Feature_5'] (= -0.7146973587899221) <= 0.46754080057144165)
==> Predicted leaf index 

decision id node 95 : (X_test[1, 'Feature_4'] (= 2.2422482391211678) > -2.0)
Prediction for sample 1: 0.0



Tree 4

The binary tree structure has 159 nodes
[ 9.          1.32658511 -0.08002818  0.88295736  2.24224824 -0.71469736]
Rules used to predict sample 1: 
decision id node 0 : (X_test[1, 'Feature_3'] (= 0.8829573603562209) > 0.673632800579071)
decision id node 130 : (X_test[1, 'Feature_3'] (= 0.8829573603562209) <= 0.9460248947143555)
decision id node 131 : (X_test[1, 'Feature_2'] (= -0.08002817952064323) > -0.413898229598999)
decision id node 143 : (X_test[1, 'Feature_5'] (= -0.7146973587899221) > -1.4878294467926025)
==> Predicted leaf index 

decision id node 147 : (X_test[1, 'Feature_4'] (= 2.2422482391211678) > -2.0)
Prediction for sample 1: 1.0



Tree 5

The binary tree structure has 123 nodes
[ 9.          1.32658511 -0.08002818  0.88295736  2.24224824 -0.71469736]
Rules used to predict sample 1: 
decision id node 0 : (X_test[1, 'Feature_1'] (= 1.3265851104352575) > -0.3723614811897278)
decision id node 18 : (X_test[1, 'Feature_4'] (= 2.2422482391211678) > 0.08041483908891678)
decision id node 40 : (X_test[1, 'Feature_4'] (= 2.2422482391211678) > 1.4040119647979736)
decision id node 86 : (X_test[1, 'Feature_3'] (= 0.8829573603562209) > 0.6183018684387207)
decision id node 112 : (X_test[1, 'Feature_2'] (= -0.08002817952064323) <= 0.02021688222885132)
decision id node 113 : (X_test[1, 'Feature_5'] (= -0.7146973587899221) <= -0.7041467428207397)
==> Predicted leaf index 

decision id node 114 : (X_test[1, 'Feature_4'] (= 2.2422482391211678) > -2.0)
Prediction for sample 1: 0.0



Tree 6

The binary tree structure has 119 nodes
[ 9.          1.32658511 -0.08002818  0.88295736  2.24224824 -0.71469736]
Rules used to predict sample 1: 
decision id node 0 : (X_test[1, 'Feature_2'] (= -0.08002817952064323) <= 2.1899335384368896)
decision id node 1 : (X_test[1, 'Feature_1'] (= 1.3265851104352575) > -0.6121580600738525)
decision id node 13 : (X_test[1, 'Feature_5'] (= -0.7146973587899221) <= 2.0935070514678955)
decision id node 14 : (X_test[1, 'Feature_2'] (= -0.08002817952064323) <= 0.21425755321979523)
decision id node 15 : (X_test[1, 'Feature_3'] (= 0.8829573603562209) > -0.20205456018447876)
decision id node 19 : (X_test[1, 'Feature_3'] (= 0.8829573603562209) <= 0.9333831071853638)
decision id node 20 : (X_test[1, 'Feature_5'] (= -0.7146973587899221) <= 0.8492134809494019)
decision id node 21 : (X_test[1, 'Feature_5'] (= -0.7146973587899221) <= -0.5752145648002625)
decision id node 22 : (X_test[1, 'Feature_4'] (= 2.2422482391211678) > 0.039557620882987976)
decision id node 28 : (X_test[1, 'Feature_2'] (= -0.08002817952064323) > -0.5093963146209717)
decision id node 30 : (X_test[1, 'Feature_4'] (= 2.2422482391211678) > 0.8792927265167236)
decision id node 32 : (X_test[1, 'Feature_2'] (= -0.08002817952064323) <= -0.014746115542948246)
decision id node 33 : (X_test[1, 'Feature_1'] (= 1.3265851104352575) > 0.7183631658554077)
==> Predicted leaf index 

decision id node 35 : (X_test[1, 'Feature_4'] (= 2.2422482391211678) > -2.0)
Prediction for sample 1: 0.0



Tree 7

The binary tree structure has 129 nodes
[ 9.          1.32658511 -0.08002818  0.88295736  2.24224824 -0.71469736]
Rules used to predict sample 1: 
decision id node 0 : (X_test[1, 'Feature_3'] (= 0.8829573603562209) <= 0.9460248947143555)
decision id node 1 : (X_test[1, 'Feature_3'] (= 0.8829573603562209) > 0.031008129939436913)
decision id node 35 : (X_test[1, 'Feature_2'] (= -0.08002817952064323) <= 0.22173278033733368)
decision id node 36 : (X_test[1, 'Feature_3'] (= 0.8829573603562209) > 0.2683340311050415)
decision id node 54 : (X_test[1, 'Feature_0'] (= 9.0) > 1.5)
decision id node 66 : (X_test[1, 'Feature_1'] (= 1.3265851104352575) > 0.9638141393661499)
decision id node 70 : (X_test[1, 'Feature_3'] (= 0.8829573603562209) > 0.41362708806991577)
decision id node 72 : (X_test[1, 'Feature_2'] (= -0.08002817952064323) > -1.3964011669158936)
decision id node 82 : (X_test[1, 'Feature_4'] (= 2.2422482391211678) > 0.4883882403373718)
decision id node 84 : (X_test[1, 'Feature_1'] (= 1.3265851104352575) <= 1.3384532928466797)
==> Predicted leaf index 

decision id node 85 : (X_test[1, 'Feature_4'] (= 2.2422482391211678) > -2.0)
Prediction for sample 1: 0.0



Tree 8

The binary tree structure has 135 nodes
[ 9.          1.32658511 -0.08002818  0.88295736  2.24224824 -0.71469736]
Rules used to predict sample 1: 
decision id node 0 : (X_test[1, 'Feature_1'] (= 1.3265851104352575) > -0.3356468677520752)
decision id node 24 : (X_test[1, 'Feature_1'] (= 1.3265851104352575) <= 1.3858520984649658)
decision id node 25 : (X_test[1, 'Feature_4'] (= 2.2422482391211678) > 0.08041483908891678)
decision id node 37 : (X_test[1, 'Feature_2'] (= -0.08002817952064323) <= 0.2188812792301178)
decision id node 38 : (X_test[1, 'Feature_3'] (= 0.8829573603562209) > 0.7280778884887695)
decision id node 68 : (X_test[1, 'Feature_5'] (= -0.7146973587899221) > -1.3525903224945068)
decision id node 70 : (X_test[1, 'Feature_0'] (= 9.0) > 7.0)
==> Predicted leaf index 

decision id node 72 : (X_test[1, 'Feature_4'] (= 2.2422482391211678) > -2.0)
Prediction for sample 1: 0.0



Tree 9

The binary tree structure has 121 nodes
[ 9.          1.32658511 -0.08002818  0.88295736  2.24224824 -0.71469736]
Rules used to predict sample 1: 
decision id node 0 : (X_test[1, 'Feature_5'] (= -0.7146973587899221) > -0.9643558859825134)
decision id node 40 : (X_test[1, 'Feature_1'] (= 1.3265851104352575) > -0.3356468677520752)
decision id node 54 : (X_test[1, 'Feature_4'] (= 2.2422482391211678) > -0.23724764585494995)
decision id node 62 : (X_test[1, 'Feature_1'] (= 1.3265851104352575) <= 1.3771547079086304)
decision id node 63 : (X_test[1, 'Feature_2'] (= -0.08002817952064323) <= 0.09199776500463486)
decision id node 64 : (X_test[1, 'Feature_2'] (= -0.08002817952064323) > -0.9304192066192627)
decision id node 74 : (X_test[1, 'Feature_4'] (= 2.2422482391211678) > 0.6660356521606445)
decision id node 78 : (X_test[1, 'Feature_1'] (= 1.3265851104352575) <= 1.3506155014038086)
==> Predicted leaf index 

decision id node 79 : (X_test[1, 'Feature_4'] (= 2.2422482391211678) > -2.0)
Prediction for sample 1: 0.0
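
Note that the ten trees disagree on sample 1: four predict 1.0 and six predict 0.0. The forest's final answer is obtained by averaging each tree's leaf class probabilities (soft voting) rather than counting these hard labels. A minimal sketch of that aggregation, reusing the estimator fitted above:

# Average the per-tree class probabilities for sample 1, as the forest does.
tree_probas = [t.predict_proba(X_test[1:2]) for t in estimator.estimators_]
print("Mean class probabilities:", np.mean(tree_probas, axis=0))
print("Forest prediction:", estimator.predict(X_test[1:2]))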
