Linux:
1. First, install Anaconda, Graphviz, and pydot
Anaconda ships with many packages preinstalled, so the only things left to install are graphviz and pydot.
Remember to use the pip that comes with Anaconda:
pip install graphviz
pip install pydot
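Next, if Anaconda's packages are not already on PATH, edit ~/.bashrc and add the locations of graphviz and pydot. Note that the graphviz package installed by pip contains only the Python bindings; pydot also needs the Graphviz dot executable itself to be installed and reachable through PATH.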
export PATH="/root/anaconda/install/lib/python3.6/site-packages/graphviz:$PATH"
This is the path on my CentOS machine.
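One quick way to confirm that the PATH change took effect is to ask Python, from the same environment, whether it can locate the Graphviz dot executable. The following is only a minimal sketch: the helper name find_executable_on_path is made up for illustration, and it assumes dot is the program that pydot will call.

```python
import shutil

def find_executable_on_path(name="dot"):
    """Return the full path of `name` if it can be found on PATH, else None."""
    return shutil.which(name)

if __name__ == "__main__":
    location = find_executable_on_path("dot")
    if location is None:
        print("dot not found on PATH; check ~/.bashrc and reopen the shell")
    else:
        print("dot found at", location)
```

If this prints a path, the rendering step in the next section should be able to find Graphviz.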
2. Visualize the decision tree directly in Jupyter notebook
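from io import StringIO  # sklearn.externals.six is no longer available in newer scikit-learn releases
from sklearn.tree import DecisionTreeRegressor, export_graphviz
import pydot
import graphviz
from IPython.display import Image

# X_train, y_train and ls_col come from your own data; ls_col[:-1] is used as the feature names
clf = DecisionTreeRegressor(max_depth=4)
clf = clf.fit(X_train, y_train)

# Write the tree as DOT text into an in-memory buffer
dot_data = StringIO()
export_graphviz(clf, feature_names=ls_col[:-1], out_file=dot_data)

# Parse the DOT text and render it as a PNG inside the notebook
(graph,) = pydot.graph_from_dot_data(dot_data.getvalue())
Image(graph.create_png())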
At this point, if nothing went wrong, the tree diagram is displayed.
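To get a feel for what the DOT text in dot_data looks like, here is a small hand-written sketch that builds DOT for a tree stored as parallel lists. It is not scikit-learn's implementation; the function tree_to_dot and the three-node tree are hypothetical, and the lists only imitate the layout of clf.tree_ (with -1 meaning "no child").

```python
def tree_to_dot(children_left, children_right, feature, threshold, value, feature_names):
    """Build DOT text for a binary tree stored as parallel lists (-1 marks 'no child')."""
    lines = ["digraph Tree {", "node [shape=box];"]
    for node in range(len(children_left)):
        if children_left[node] == -1:
            # Leaf: show only the predicted value
            label = "value = {}".format(value[node])
        else:
            # Internal node: show the split test; "\\n" becomes a line break in the DOT label
            label = "{} <= {}\\nvalue = {}".format(
                feature_names[feature[node]], threshold[node], value[node])
        lines.append('{} [label="{}"];'.format(node, label))
        for child in (children_left[node], children_right[node]):
            if child != -1:
                lines.append("{} -> {};".format(node, child))
    lines.append("}")
    return "\n".join(lines)

# Hypothetical three-node tree: one split on "x0", two leaves
dot_text = tree_to_dot(
    children_left=[1, -1, -1],
    children_right=[2, -1, -1],
    feature=[0, -2, -2],
    threshold=[0.5, -2.0, -2.0],
    value=[0.3, 0.1, 0.6],
    feature_names=["x0"],
)
print(dot_text)
```

Text of this shape is what pydot.graph_from_dot_data parses before Graphviz draws it.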
3. Extract rules from the decision tree
def get_code(tree, feature_names):
    """Print the fitted tree as nested if/else pseudocode."""
    left = tree.tree_.children_left
    right = tree.tree_.children_right
    threshold = tree.tree_.threshold
    # Leaf nodes carry a negative feature index, so they get no feature name
    features = [feature_names[i] if i >= 0 else None for i in tree.tree_.feature]
    value = tree.tree_.value
    dent_s = '    '  # extra indentation added per tree level

    def recurse(node, dent):
        # Every node, internal or leaf, prints its current value
        print(dent + "current value: " + str(value[node]))
        if left[node] == -1:  # -1 marks a leaf: nothing to descend into
            return
        print(dent + "if ( " + features[node] + " <= " + str(threshold[node]) + " ) {")
        recurse(left[node], dent_s + dent)
        print(dent + "} else {")
        recurse(right[node], dent_s + dent)
        print(dent + "}")

    recurse(0, '')
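Continuing the example above, call the rule-printing function:
get_code(clf, ls_col[:-1])  # second argument: the list of feature names
The output looks like this: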
…
current value: [[ 0.00847842]]
if ( potential1_index_3 <= -46.3385009766 ) {
    current value: [[ 0.01188856]]
    if ( factorQua1_index_3 <= 1.16363739967 ) {
        current value: [[ 0.01749776]]
        if ( potential_index_3 <= 0.000233492843108 ) {
            current value: [[ 0.01443796]]
            if ( chgOC_index_3 <= -0.00361792766489 ) {
                current value: [[ 0.02022273]]
            } else {
                current value: [[ 0.00757697]]
            }
        } else {
            current value: [[ 0.03464071]]
            if ( quaOnIndexR3_3 <= 0.158238857985 ) {
                current value: [[ 0.03379166]]
            } else {
                current value: [[ 0.09789488]]
            }
        }
    } else {
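The nested output is easy to read for shallow trees, but sometimes a flat list with one rule per leaf is more convenient. Below is a rough sketch of that variant. The function leaf_rules and the small five-node tree are hypothetical; the parallel lists imitate tree_.children_left, tree_.children_right, tree_.feature, tree_.threshold and tree_.value, with the values simplified to plain numbers.

```python
def leaf_rules(children_left, children_right, feature, threshold, value, feature_names):
    """Return one (conditions, value) pair per leaf of a tree stored as parallel lists."""
    rules = []

    def walk(node, conditions):
        if children_left[node] == -1:  # -1 marks a leaf
            rules.append((" and ".join(conditions) or "always", value[node]))
            return
        name = feature_names[feature[node]]
        cut = threshold[node]
        # Left branch satisfies the test, right branch violates it
        walk(children_left[node], conditions + ["{} <= {}".format(name, cut)])
        walk(children_right[node], conditions + ["{} > {}".format(name, cut)])

    walk(0, [])
    return rules

# Hypothetical tree: split on x0 at 0.5, then the right branch splits on x1 at 2.0
rules = leaf_rules(
    children_left=[1, -1, 3, -1, -1],
    children_right=[2, -1, 4, -1, -1],
    feature=[0, -2, 1, -2, -2],
    threshold=[0.5, -2.0, 2.0, -2.0, -2.0],
    value=[0.3, 0.1, 0.5, 0.4, 0.9],
    feature_names=["x0", "x1"],
)
for conditions, leaf_value in rules:
    print("if {}: predict {}".format(conditions, leaf_value))
```

With a fitted estimator, the same lists would come from clf.tree_ as in get_code.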
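Windows:
1. Install Anaconda
2. Download and install graphviz-XX.msi
URL: http://www.graphviz.org/
After installation, manually add the bin/ folder under the Graphviz install directory to the PATH environment variable.
A quick note on setting environment variables in Windows 10:
In File Explorer, right-click This PC -> Advanced system settings -> Advanced
and set the environment variables there.
If you are the only one using the machine, a user variable is enough; it does not need to go into the system variables.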
3. Install pydot
Use the pip that comes with Anaconda.
Note that Anaconda itself must be on PATH for this to work.
pip install pydot
4. Restart Jupyter notebook
import os; print(os.environ["PATH"])
Check whether the Graphviz path appears in the output.
If it does, you can go ahead and test.
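Scanning the full PATH printout by eye is error-prone on Windows, where it tends to be long. As a small convenience, the sketch below filters PATH down to the entries that mention Graphviz. The helper name path_entries_containing is made up for illustration, and it assumes the install directory has "graphviz" somewhere in its name.

```python
import os

def path_entries_containing(keyword, path_value=None):
    """Return the PATH entries whose text contains `keyword` (case-insensitive)."""
    if path_value is None:
        path_value = os.environ.get("PATH", "")
    # os.pathsep is ';' on Windows and ':' on Linux
    entries = [entry for entry in path_value.split(os.pathsep) if entry]
    return [entry for entry in entries if keyword.lower() in entry.lower()]

print(path_entries_containing("graphviz"))
```

An empty list suggests the bin/ folder was not added, or that the notebook was started before the variable was changed.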