sklearn.tree
模块
sklearn.tree
提供了 决策树(Decision Tree) 模型,支持 分类(Classification)和回归(Regression) 任务,常用于 模式识别、特征选择和可解释性分析。
1. sklearn.tree
主要模型
任务 | 模型 | 适用情况 |
---|---|---|
分类 | DecisionTreeClassifier | 适用于分类任务 |
回归 | DecisionTreeRegressor | 适用于回归任务 |
可视化 | plot_tree() | 绘制决策树结构 |
导出 | export_text() | 以文本格式导出决策树 |
2. 决策树分类
(1) DecisionTreeClassifier
(决策树分类)
from sklearn.tree import DecisionTreeClassifier
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
# 加载数据
iris = load_iris()
X_train, X_test, y_train, y_test = train_test_split(iris.data, iris.target, test_size=0.2, random_state=42)
# 训练决策树分类器
model = DecisionTreeClassifier(max_depth=3, random_state=42)
model.fit(X_train, y_train)
# 预测
y_pred = model.predict(X_test)
# 计算准确率
accuracy = model.score(X_test, y_test)
print("准确率:", accuracy)
解释
max_depth=3
限制决策树深度,防止过拟合。random_state=42
保持随机性一致性。score()
计算测试集准确率。
3. 决策树回归
(2) DecisionTreeRegressor
(决策树回归)
from sklearn.tree import DecisionTreeRegressor
from sklearn.datasets import make_regression
import matplotlib.pyplot as plt
# 生成回归数据
X, y = make_regression(n_samples=100, n_features=1, noise=10, random_state=42)
# 训练决策树回归模型
model = DecisionTreeRegressor(max_depth=3)
model.fit(X, y)
# 预测
y_pred = model.predict(X)
# 可视化
plt.scatter(X, y, label="真实数据")
plt.scatter(X, y_pred, color="red", label="决策树预测")
plt.legend()
plt.show()
解释
max_depth=3
限制树的深度,防止过拟合。- 适用于非线性回归任务,但 容易过拟合。
4. DecisionTreeClassifier
& DecisionTreeRegressor
主要参数
DecisionTreeClassifier(criterion="gini", max_depth=None, min_samples_split=2, min_samples_leaf=1, random_state=None)
DecisionTreeRegressor(criterion="squared_error", max_depth=None, min_samples_split=2, min_samples_leaf=1, random_state=None)
参数 | 说明 |
---|---|
criterion | 分类:"gini" 或 "entropy" (基尼系数/信息增益)回归: "squared_error" (默认均方误差) |
max_depth | 树的最大深度(默认 None ,直到所有叶子纯净) |
min_samples_split | 进行分裂的最小样本数(默认 2 ) |
min_samples_leaf | 叶子节点的最小样本数(默认 1 ) |
random_state | 随机种子,保证结果可复现 |
5. 决策树可视化
(3) plot_tree()
from sklearn.tree import plot_tree
import matplotlib.pyplot as plt
plt.figure(figsize=(12, 8))
plot_tree(model, feature_names=iris.feature_names, class_names=iris.target_names, filled=True)
plt.show()
解释
plot_tree()
可视化决策树结构,查看分裂情况。filled=True
颜色填充,直观显示类别信息。
(4) export_text()
from sklearn.tree import export_text
tree_rules = export_text(model, feature_names=iris.feature_names)
print(tree_rules)
解释
- 以文本格式导出决策树规则,适用于模型解释。
6. DecisionTreeClassifier
vs. DecisionTreeRegressor
模型 | 适用情况 | 目标变量 |
---|---|---|
DecisionTreeClassifier | 分类任务 | 离散类别 |
DecisionTreeRegressor | 回归任务 | 连续数值 |
示例
from sklearn.linear_model import LinearRegression
reg = LinearRegression().fit(X, y)
tree_reg = DecisionTreeRegressor(max_depth=3).fit(X, y)
print("线性回归预测:", reg.predict(X)[:5])
print("决策树回归预测:", tree_reg.predict(X)[:5])
解释
- 线性回归适用于数值预测,决策树回归适用于非线性数据。
7. 适用场景
- 分类任务(如 垃圾邮件检测、信用评分)。
- 回归任务(如 房价预测、股票趋势分析)。
- 数据可解释性强的场景(决策树可以直接解释决策过程)。
8. 结论
sklearn.tree
提供了分类和回归任务的决策树模型,支持 可视化和文本导出。- 如果 数据是分类问题,可以使用
DecisionTreeClassifier
;如果是回归问题,可以使用DecisionTreeRegressor
。