import pydot
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier
from sklearn.tree import export_graphviz

try:
    # Legacy module: deprecated in scikit-learn 0.18 and removed in 0.20.
    # Guarded so the file still imports on newer releases, where
    # sklearn.model_selection provides train_test_split instead.
    from sklearn import cross_validation
except ImportError:
    cross_validation = None
def load_data(test_size=0.25, random_state=0):
    """Load the Iris dataset and split it into stratified train/test sets.

    Parameters
    ----------
    test_size : float or int, default 0.25
        Proportion (float) or absolute number (int) of samples held out
        for testing, as accepted by ``train_test_split``.
    random_state : int or None, default 0
        Seed for the shuffle so the split is reproducible.

    Returns
    -------
    tuple
        ``(X_train, X_test, y_train, y_test)`` as returned by
        ``sklearn.model_selection.train_test_split``.
    """
    iris = datasets.load_iris()
    # Full feature matrix and labels (not yet a training split).
    X, y = iris.data, iris.target
    # Stratify on the labels so each class keeps its proportion in both splits.
    return train_test_split(
        X,
        y,
        test_size=test_size,
        random_state=random_state,
        stratify=y,
    )
# Train a decision tree on the Iris training split and render it to tree.png.
X_train, X_test, y_train, y_test = load_data()

# Seed the tree: scikit-learn permutes the features at every split, so an
# unseeded tree (and therefore the rendered image) can differ between runs.
clf = DecisionTreeClassifier(random_state=0)
clf.fit(X_train, y_train)

# Evaluate on the held-out split; otherwise the test data is never used.
print('Test accuracy: {:.3f}'.format(clf.score(X_test, y_test)))

# Label nodes with real feature/class names instead of X[i] and class indices.
iris_info = datasets.load_iris()
export_graphviz(
    clf,
    out_file='tree.dot',
    feature_names=iris_info.feature_names,
    class_names=iris_info.target_names,
    rounded=True,
    precision=1,
)

# Newer pydot releases return a list of graphs; older ones return a single
# Dot object, so accept both.
graphs = pydot.graph_from_dot_file('tree.dot')
graph = graphs[0] if isinstance(graphs, list) else graphs
# Rendering to PNG requires the Graphviz binaries to be installed.
graph.write_png('tree.png')
决策树——决策图
最新推荐文章于 2024-08-04 21:08:26 发布