Decision tree rule extraction in Python: retrieving the rule path of a data point through a decision tree with sklearn
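The snippet below is essentially scikit-learn's "Understanding the decision tree structure" example: it fits a small tree on the iris data, walks the parallel arrays stored in estimator.tree_ to print the tree, and then uses decision_path and apply to recover the rules that a single test sample passes through.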

import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

iris = load_iris()
X = iris.data
y = iris.target

X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

estimator = DecisionTreeClassifier(max_leaf_nodes=3, random_state=0)
estimator.fit(X_train, y_train)

# The decision estimator has an attribute called tree_ which stores the entire
# tree structure and allows access to low-level attributes. The binary tree
# tree_ is represented as a number of parallel arrays. The i-th element of each
# array holds information about the node `i`. Node 0 is the tree's root.
# NOTE: Some of the arrays only apply to either leaves or split nodes; in that
# case the values for nodes of the other type are arbitrary!
#
# Among those arrays, we have:
#   - left_child, id of the left child of the node
#   - right_child, id of the right child of the node
#   - feature, feature used for splitting the node
#   - threshold, threshold value at the node
n_nodes = estimator.tree_.node_count
children_left = estimator.tree_.children_left
children_right = estimator.tree_.children_right
feature = estimator.tree_.feature
threshold = estimator.tree_.threshold

# The tree structure can be traversed to compute various properties such
# as the depth of each node and whether or not it is a leaf.
node_depth = np.zeros(shape=n_nodes, dtype=np.int64)
is_leaves = np.zeros(shape=n_nodes, dtype=bool)
stack = [(0, -1)]  # seed is the root node id and its parent depth
while len(stack) > 0:
    node_id, parent_depth = stack.pop()
    node_depth[node_id] = parent_depth + 1

    # If we have a test node
    if children_left[node_id] != children_right[node_id]:
        stack.append((children_left[node_id], parent_depth + 1))
        stack.append((children_right[node_id], parent_depth + 1))
    else:
        is_leaves[node_id] = True

print("The binary tree structure has %s nodes and has "
      "the following tree structure:" % n_nodes)
for i in range(n_nodes):
    if is_leaves[i]:
        print("%snode=%s leaf node." % (node_depth[i] * "\t", i))
    else:
        print("%snode=%s test node: go to node %s if X[:, %s] <= %s else to "
              "node %s." % (node_depth[i] * "\t",
                            i,
                            children_left[i],
                            feature[i],
                            threshold[i],
                            children_right[i]))
print()

# First let's retrieve the decision path of each sample. The decision_path
# method returns the node indicator matrix: a non-zero element at position
# (i, j) indicates that sample i goes through node j.
node_indicator = estimator.decision_path(X_test)

# Similarly, we can also get the id of the leaf reached by each sample.
leave_id = estimator.apply(X_test)

# Now it's possible to get the tests that were used to predict a sample or
# a group of samples. First, let's do it for a single sample.
# HERE IS WHAT YOU WANT
sample_id = 0
node_index = node_indicator.indices[node_indicator.indptr[sample_id]:
                                    node_indicator.indptr[sample_id + 1]]

print('Rules used to predict sample %s: ' % sample_id)
for node_id in node_index:
    if leave_id[sample_id] == node_id:
        continue

    if X_test[sample_id, feature[node_id]] <= threshold[node_id]:
        threshold_sign = "<="
    else:
        threshold_sign = ">"

    print("decision id node %s : (X_test[%s, %s] (= %s) %s %s)"
          % (node_id,
             sample_id,
             feature[node_id],
             X_test[sample_id, feature[node_id]],
             threshold_sign,
             threshold[node_id]))
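The loop above prints the rules; if you would rather get them back as a list of readable strings (for example, using the iris feature names), the same indices can be collected by a small helper. This is a minimal sketch of my own, not part of the original example: it assumes the variables defined above are in scope, and the name rule_path_for_sample is hypothetical.

def rule_path_for_sample(sample_id, X, node_indicator, leave_id,
                         feature, threshold, feature_names):
    # Collect one "feature <= threshold" / "feature > threshold" string
    # per internal node on the sample's path from root to leaf.
    rules = []
    node_index = node_indicator.indices[node_indicator.indptr[sample_id]:
                                        node_indicator.indptr[sample_id + 1]]
    for node_id in node_index:
        if leave_id[sample_id] == node_id:  # skip the leaf itself
            continue
        name = feature_names[feature[node_id]]
        if X[sample_id, feature[node_id]] <= threshold[node_id]:
            rules.append("%s <= %.3f" % (name, threshold[node_id]))
        else:
            rules.append("%s > %.3f" % (name, threshold[node_id]))
    return rules

print(rule_path_for_sample(0, X_test, node_indicator, leave_id,
                           feature, threshold, iris.feature_names))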

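If you only need the full rule set rather than the path of one sample, newer scikit-learn releases (0.21 and later) also provide sklearn.tree.export_text, which renders the fitted tree as indented if/else rules:

from sklearn.tree import export_text

# Render every split of the fitted tree as indented "feature <= threshold" rules.
print(export_text(estimator, feature_names=iris.feature_names))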