#coding:utf-8
from sklearn import tree
import os
from sklearn.datasets import load_iris
from sklearn.externals.six import StringIO
import pydot
def test_decision_tree_classifier():
    """Fit a decision tree classifier on iris, print predictions, export the tree.

    Trains ``tree.DecisionTreeClassifier`` on ``iris.data`` / ``iris.target``,
    prints the predicted class and the per-class probabilities for one
    hard-coded sample, and writes the fitted tree in Graphviz DOT format to
    ``iris.dot`` in the current working directory.

    Returns:
        None. The results are printed and written to disk.
    """
    # A single hard-coded sample with four feature values.
    test = [5.2, 4.0, 7.1, 0.3]
    iris = load_iris()
    clf = tree.DecisionTreeClassifier()
    clf = clf.fit(iris.data, iris.target)
    # Simple prediction. print(...) with exactly one argument produces the
    # same output under Python 2 and Python 3.
    print(str(clf.predict([test])))
    print(str(clf.predict_proba([test])))
    # Export the tree. Call the public ``tree.export_graphviz`` entry point
    # rather than going through the ``tree.export`` submodule, and do not
    # rebind the open file handle to the call's return value.
    with open('iris.dot', 'w') as f:
        tree.export_graphviz(clf, out_file=f)
def test_decision_tree_regressor():
    """Fit a decision tree regressor on iris and print one prediction.

    Trains ``tree.DecisionTreeRegressor`` on ``iris.data`` with
    ``iris.target`` as the regression target, then prints the predicted
    value for one hard-coded sample.

    Returns:
        None. The prediction is printed.
    """
    # A single hard-coded sample with four feature values.
    test = [5.2, 4.0, 7.1, 0.3]
    iris = load_iris()
    clf = tree.DecisionTreeRegressor()
    clf = clf.fit(iris.data, iris.target)
    # print(...) with exactly one argument produces the same output under
    # Python 2 and Python 3, unlike the print statement.
    print(str(clf.predict([test])))
# Script entry point: run both demos only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    test_decision_tree_classifier()
    test_decision_tree_regressor()
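# Python 决策树 (Python decision trees)
# 最新推荐文章于 2024-05-23 23:44:50 发布 (page footer: latest recommended article published 2024-05-23 23:44:50)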