参考 scikit-learn v0.19.0,此处使用鸢尾花(Iris)数据集实现决策树分类。
#!/usr/bin/python
# -*- coding:utf-8 -*-
import numpy as np
import pandas as pd
def main():
    """Train an entropy-based decision tree on the iris data and report results.

    Splits the iris dataset into train/test parts (``test_size=0.2``,
    ``random_state=1``), fits a ``DecisionTreeClassifier`` with
    ``criterion='entropy'``, prints the accuracy score and the confusion
    matrix for the held-out part, and writes the fitted tree in Graphviz
    format to ``./tree.dot``.
    """
    from sklearn import metrics
    from sklearn import tree
    from sklearn.datasets import load_iris

    # sklearn.cross_validation was deprecated in 0.18 and removed in 0.20;
    # train_test_split now lives in sklearn.model_selection. Fall back to the
    # old module only so that pre-0.18 installations keep working.
    try:
        from sklearn.model_selection import train_test_split
    except ImportError:
        from sklearn.cross_validation import train_test_split

    # Preprocessing: load the data and hold out 20% for evaluation.
    iris = load_iris()
    train_data, test_data, train_target, test_target = train_test_split(
        iris.data, iris.target, test_size=0.2, random_state=1)

    # Fit the model and predict on the held-out samples.
    clf = tree.DecisionTreeClassifier(criterion='entropy')
    clf.fit(train_data, train_target)
    y_pred = clf.predict(test_data)

    # Evaluation.
    print(metrics.accuracy_score(y_true=test_target, y_pred=y_pred))
    print(metrics.confusion_matrix(y_true=test_target, y_pred=y_pred))

    # Dump the fitted tree so it can be rendered with Graphviz.
    with open("./tree.dot", "w") as fw:
        tree.export_graphviz(clf, out_file=fw)


if __name__ == "__main__":
    main()
输出为:
/usr/local/lib/python2.7/site-packages/sklearn/cross_validation.py:41: DeprecationWarning: This module was deprecated in version 0.18 in favor of the model_selection module into which all the refactored classes and functions are moved. Also note that the interface of the new CV iterators are different from that of this module. This module will be removed in 0.20.
"This module will be removed in 0.20.", DeprecationWarning)
0.966666666667
[[11  0  0]
 [ 0 12  1]
 [ 0  0  6]]