from sklearn.datasets import load_iris
from sklearn import tree
import pydotplus
from sklearn.metrics import precision_recall_curve
from sklearn.metrics import classification_report
from sklearn.model_selection import train_test_split
from numpy import *
# Load the bundled iris dataset; `.data` is passed to the model as the
# feature matrix and `.target` as the class labels.
iris = load_iris()

# Decision tree model with scikit-learn's default hyperparameters.
clf = tree.DecisionTreeClassifier()
# Train the model on the full dataset.
clf = clf.fit(iris.data, iris.target)

# Predict the class and the per-class probabilities for the first five rows.
# NOTE: these rows were also part of the training data, so this is only a
# smoke test, not an estimate of how well the model generalizes.
first_rows = iris.data[:5, :]
result = clf.predict(first_rows)
result_prob = clf.predict_proba(first_rows)
print(result)
print(result_prob)

# Write the fitted tree to a Graphviz dot file.
# The call must sit inside the `with` body so the handle is still open, and
# its return value is deliberately not assigned: export_graphviz returns None
# when `out_file` is given, so rebinding `f` would only clobber the handle.
with open("iris.dot", "w", encoding="utf-8") as f:
    tree.export_graphviz(clf, out_file=f)
# Output:
# [0 0 0 0 0]
# [[1. 0. 0.]
#  [1. 0. 0.]
#  [1. 0. 0.]
#  [1. 0. 0.]
#  [1. 0. 0.]]
# ...
# Draw the decision tree: with out_file=None the Graphviz dot source is
# returned as a string instead of being written to a file.
dot_data = tree.export_graphviz(decision_tree=clf, out_file=None)
graph=pydotplus.graph_from_d