决策树
正好,我们专业数学基础与机器学习这两门课都有关于决策树的内容。在本篇文
章中,主要是讲最普通的决策树———基于信息增益的决策树。
首先介绍一下决策树与算法步骤;之后为方便理解,给个例题简单算一算;最
后,用Python自带的包,简单在电脑上实现一下。
1、简单介绍与算法步骤:
2、例题:
3、Python实现:
注:我这个数据集用的是Python自带的莺尾花数据集,而且如果想要可视化的话
,需要安装pydotplus与graphviz环境。
代码如下:
from sklearn import tree#决策树分类器
import numpy as np
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris #导入数据集
import pydotplus
import os
os.environ["PATH"] += os.pathsep + 'E:/graphviz/bin'
iris = load_iris()
# 解析数据
x=iris.data
y=iris.target.reshape(-1,1)
# 划分训练集和测试集
x_train,x_test,y_train,y_test=train_test_split(x,y,test_size=0.3)
module = tree.DecisionTreeClassifier(criterion='gini',max_depth=3)
module.fit(x_train, y_train.ravel())
dot_data = tree.export_graphviz(module, out_file=None)
graph = pydotplus.graph_from_dot_data(dot_data)
graph.write_pdf("iris3.pdf")
Z = module.predict(x_test)
plt.plot(x_test[0:45],y_test[0:45],c = 'b')
plt.plot(x_test[0:45],Z[0:45],c = 'r')
plt.show()
步骤与结果如下: