0、基础构建
利用sklearn
的鸢尾属植物数据编写简单决策树。
首先是获取数据(干净数据不需要处理),然后需要对规则进行可视化。这一步需要配置 搭建Graphviz
环境
1、鸢尾属植物数据集描述
这个数据集大家肯定熟悉的不能在熟悉了, 但是这三类植物图片不是所有人看过的(笔者也是)。所以笔者将几个预测目标的图片汇总如下:
从图片就可以看出 Iris Setosa(山鸢尾)
和 Iris Versicolour(变色鸢尾)
与
Iris Virginica(维吉尼亚鸢尾)
在花瓣的长宽上区别明显。
2、建模
# 加载数据 鸢尾属植物数据集
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier
from sklearn.tree import export_graphviz # 可视化树
import pydotplus
def Get_irir():
iris = load_iris()
return iris.data[:, 2:], iris.target
def Tree_graph(decision_tree, out_root):
# 将决策树可视化,输出pdf(需要配置graphivz环境)
iris = load_iris()
f_names = iris.feature_names[2: ]
class_name = iris.target_names
dot_data = export_graphviz(decision_tree, out_file = None,
feature_names = f_names,
class_names = class_name, special_characters = True,
rounded = True, filled = True )
graph = pydotplus.graph_from_dot_data(daot_data)
return graph.write_pdf(out_root)
## 简单的进行2层树深拟合
if __name__ == '__main__':
x, y = Get_irir()
tree_clf = DecisionTreeClassifier(max_depth = 2)
tree_clf.fit(x, y)
## 决策树可视化 ##
Tree_graph(tree_clf)
## 预测
print(tree_clf.predict([[5, 1.5]]))
print(tree_clf.predict_proba([[5, 1.5]]))
3、决策树预测结果如下
[1]
# Iris-Setosa(0/54)为 0%,Iris-Versicolor为 (49/54),即 90.7%, Iris-Virginica (5/54),即 9.3 %。
[[0. 0.90740741 0.09259259]]
规则如下