The failed attempt
pydot+graphviz
with open("./tree.dot", "w") as f:
    tree.export_graphviz(regr_1, out_file=f)
(graph,) = pydot.graph_from_dot_file("./tree.dot")
graph.write_png("./tree.png")
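After that it kept reporting that the dot file could not be found, over and over, until I was ready to bang my head against the wall.
I even read through the source of pydot.py to find where the error message comes from.
I even modified a function in it, following suggestions found online.
And none of it did any good~~~~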
What worked
pydotplus+graphviz
The process is as follows:
Step 1: install the required tools
pip install pydotplus
pip install graphviz
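The two commands above usually cause no trouble. The important part is the second graphviz, below.
The pip package is not enough: Graphviz itself also has to be installed on the machine. On Linux that is easy: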
yum install graphviz
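A quick way to confirm that the system-level Graphviz is visible to Python is to look for the `dot` executable on PATH, since pydotplus calls it when writing a PNG. The sketch below uses only the standard library and assumes Python 3.7 or later; the helper name `find_dot` is made up for illustration.

```python
import shutil
import subprocess


def find_dot():
    """Return (path, version_text) for the Graphviz `dot` binary, or (None, None)."""
    path = shutil.which("dot")  # searches the directories listed in PATH
    if path is None:
        return None, None
    # `dot -V` prints a version banner; read stderr first, then fall back to stdout.
    result = subprocess.run([path, "-V"], capture_output=True, text=True)
    version = (result.stderr or result.stdout).strip()
    return path, version


dot_path, dot_version = find_dot()
if dot_path is None:
    print("dot not found on PATH: install Graphviz itself, not just the pip package")
else:
    print("found", dot_path, "-", dot_version)
```

If this reports that `dot` is missing, writing a PNG will most likely fail no matter which Python wrapper is used.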
Step 2: the full procedure
import pandas as pd
import numpy as np
from sklearn import tree
import graphviz
import pydot
import pydotplus
from sklearn.metrics import mean_squared_error,mean_absolute_error
from sklearn.tree import DecisionTreeRegressor
X_train=pd.read_csv("../state/data/trainData/X_train.csv",index_col=0)
X_test=pd.read_csv("../state/data/trainData/X_test.csv",index_col=0)
Y_train=pd.read_csv("../state/data/trainData/Y_train.csv",index_col=0)
Y_test=pd.read_csv("../state/data/trainData/Y_test.csv",index_col=0)
X_train_array=np.asarray(X_train)
X_test_array=np.asarray(X_test)
Y_train_array=np.asarray(Y_train)
Y_test_array=np.asarray(Y_test)
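# reshape returns a new array instead of modifying in place, so keep the result
Y_train_array = Y_train_array.reshape(len(Y_train_array), -1)
Y_test_array = Y_test_array.reshape(len(Y_test_array), -1)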
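for dep in [3]:
    regr_1 = DecisionTreeRegressor(max_depth=dep)
    # criterion = "mse": the split-quality criterion. For regression choose mse or mae
    #   (renamed "squared_error" / "absolute_error" in scikit-learn 1.0). For classification
    #   choose gini or entropy, the criteria associated with CART and C4.5 respectively
    #   (scikit-learn itself uses an optimized CART either way).
    # splitter = "best": "best" takes the best split, "random" takes the best random split.
    # max_depth = None: maximum depth of the tree
    # min_samples_split = 2: minimum number of samples needed to split a node, default 2
    # min_samples_leaf = 1: minimum number of samples at a leaf node
    # min_weight_fraction_leaf = 0.,
    # max_features = None,
    # random_state = None,
    # max_leaf_nodes = None,
    # min_impurity_decrease = 0.,
    # min_impurity_split = None,  (removed in newer scikit-learn releases)
    # presort = 'deprecated',  (removed in newer scikit-learn releases)
    # ccp_alpha
    regr_1 = regr_1.fit(X_train_array, Y_train_array)
    y_pre = regr_1.predict(X_test_array)
    # scikit-learn metrics take (y_true, y_pred) in that order
    mse = mean_squared_error(Y_test_array, y_pre)
    mae = mean_absolute_error(Y_test_array, y_pre)
    print("tree depth", dep, " mae: ", mae, " mse ", mse)
    # Write the DOT source to disk; export_graphviz returns None when given a file handle
    with open("tree.dot", 'w') as f:
        tree.export_graphviz(regr_1, out_file=f)
    # (graph,) = pydot.graph_from_dot_file("./tree.dot")
    # (graph,) = pydotplus.graph_from_dot_file("./tree.dot")
    # With out_file=None the DOT source comes back as a string, so no file lookup is needed
    dot_data = tree.export_graphviz(regr_1, out_file=None,
                                    # feature_names=iris.feature_names,
                                    # class_names=iris.target_names,
                                    filled=True, rounded=True,
                                    special_characters=True)
    graph = pydotplus.graph_from_dot_data(dot_data)
    graph.write_png("./tree.png")
    #
    # graph.save()
    # dot.view()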
Much better~
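If Graphviz cannot be installed at all, the fitted tree can still be inspected as plain text. The sketch below is an alternative to the PNG route, not part of the workflow above: it assumes scikit-learn 0.21 or later (where `sklearn.tree.export_text` is available) and uses synthetic data in place of the CSV files, with made-up feature names `x0` and `x1`.

```python
import numpy as np
from sklearn.tree import DecisionTreeRegressor, export_text

# Synthetic stand-in for X_train / Y_train (the real CSV files are not used here).
rng = np.random.RandomState(0)
X = rng.uniform(0, 10, size=(200, 2))
y = 3.0 * X[:, 0] + rng.normal(scale=0.5, size=200)

regr = DecisionTreeRegressor(max_depth=3, random_state=0)
regr.fit(X, y)

# export_text needs no Graphviz: it returns the split rules as an indented string.
rules = export_text(regr, feature_names=["x0", "x1"])
print(rules)
```

The text dump carries the same split thresholds and leaf values as the picture, which is often enough for a quick sanity check of a shallow tree.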