用sklearn实现决策树的模型构建
简介
具体查询官方文档
DecisionTrees可以用于分类问题和预测问题。
决策树的优点是模型具有可读性,分类速度快。
决策树学习通常包括三个步骤:特征选择、决策树的生成和决策树的修剪。
常见的决策树算法:ID3,C4.5,CART
环境配置
sklearn更新最新版本时,显示“无法定位程序输入点……”之类的问题。我的处理流程:
- 将DLLs中的libssl-1_1-x64复制到library/bin中(bin中的文件被替换,换成旧的时间)
- 在terminal中输入conda update python
- 将bin中的文件换回原来更新时间的版本
- 使用conda install openssl --force-reinstall重新安装OpenSSL以确保它是最新的。
- 更新sklearn
conda update scikit-learn
应用
流程
- 导入相关的函数
- 读取数据,训练数据X,y
- 定义模型
- 拟合模型
- 对验证数据进行预测
分类问题
from sklearn import tree
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
# X, y = load_iris(return_X_y=True)
# iris = load_iris(return_X_y=False)
# model = tree.DecisionTreeClassifier()
# model.fit(X, y)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
estimator = DecisionTreeClassifier(max_leaf_nodes=3, random_state=0)
estimator.fit(X_train, y_train)
决策树的可视化
- 方法一:用tree.plot_tree
tree.plot_tree(model.fit(iris.data, iris.target))
- 方法二:利用Graphviz库,
conda install python-graphviz
import graphviz
dot_data = tree.export_graphviz(model, out_file=None,
... feature_names=iris.feature_names,
... class_names=iris.target_names,
... filled=True, rounded=True,
... special_characters=True)
graph = graphviz.Source(dot_data)
graph.render('iris2')
graph
- 方法三:输出为文本export_text
from sklearn.tree.export import export_text
r = export_text(model, feature_names=iris['feature_names'])
print(r)
回归问题
from sklearn.tree import DecisionTreeRegressor
X = ___
y = ___
model = DecisionTreeRegressor()
model.fit(X, y)
sor
X = ___
y = ___
model = DecisionTreeRegressor()
model.fit(X, y)