sklearn.tree.DecisionTreeRegressor
(决策树回归)
DecisionTreeRegressor
是 sklearn.tree
提供的 回归任务决策树模型,用于 预测连续数值,适用于 非线性关系的回归问题。
1. DecisionTreeRegressor
作用
- 适用于回归任务(如房价预测、股票走势预测)。
- 自动选择最优特征进行分裂,适用于 非线性数据。
- 可解释性强,可视化决策路径。
2. DecisionTreeRegressor
代码示例
(1) 训练决策树回归模型
from sklearn.tree import DecisionTreeRegressor
from sklearn.datasets import make_regression
import matplotlib.pyplot as plt
import numpy as np
# 生成回归数据
X, y = make_regression(n_samples=100, n_features=1, noise=10, random_state=42)
# 训练决策树回归模型
model = DecisionTreeRegressor(max_depth=3)
model.fit(X, y)
# 预测
X_sorted = np.sort(X, axis=0) # 排序 X 以便可视化
y_pred = model.predict(X_sorted)
# 可视化
plt.scatter(X, y, label="真实数据")
plt.plot(X_sorted, y_pred, color="red", label="决策树预测")
plt.legend()
plt.show()
解释
max_depth=3
限制决策树深度,防止过拟合。- 适用于非线性回归任务,但 容易过拟合。
3. DecisionTreeRegressor
主要参数
DecisionTreeRegressor(criterion="squared_error", max_depth=None, min_samples_split=2, min_samples_leaf=1, random_state=None)
参数 | 说明 |
---|---|
criterion | “squared_error”(默认) or “friedman_mse” or “absolute_error”(损失函数) |
max_depth | 最大树深度(默认 None ,自动生长) |
min_samples_split | 分裂内部节点的最小样本数(默认 2 ) |
min_samples_leaf | 叶子节点的最小样本数(默认 1 ) |
random_state | 设置随机种子,保证结果可复现 |
4. 获取特征重要性
feature_importances = model.feature_importances_
print("特征重要性:", feature_importances)
解释
feature_importances_
返回每个特征的重要性(值越大,该特征越重要)。
5. 计算模型性能
from sklearn.metrics import mean_squared_error, r2_score
y_test_pred = model.predict(X)
mse = mean_squared_error(y, y_test_pred)
r2 = r2_score(y, y_test_pred)
print("均方误差 MSE:", mse)
print("决定系数 R²:", r2)
解释
- MSE(均方误差):值越小,拟合效果越好。
- R²(决定系数):
1
表示完美拟合,0
表示无解释能力。
6. 决策树可视化
(1) plot_tree()
from sklearn.tree import plot_tree
import matplotlib.pyplot as plt
plt.figure(figsize=(12, 8))
plot_tree(model, filled=True)
plt.show()
解释
- 可视化决策树结构,查看分裂情况。
(2) export_text()
from sklearn.tree import export_text
tree_rules = export_text(model)
print(tree_rules)
解释
- 以文本格式导出决策树规则。
7. DecisionTreeRegressor
vs. LinearRegression
模型 | 适用情况 | 目标变量 | 适合数据 |
---|---|---|---|
LinearRegression | 线性回归 | 连续值 | 线性数据 |
DecisionTreeRegressor | 决策树回归 | 连续值 | 非线性数据 |
示例
from sklearn.linear_model import LinearRegression
reg = LinearRegression().fit(X, y)
tree_reg = DecisionTreeRegressor(max_depth=3).fit(X, y)
print("线性回归预测:", reg.predict(X)[:5])
print("决策树回归预测:", tree_reg.predict(X)[:5])
解释
- 线性回归适用于线性数据,决策树适用于非线性数据。
8. 适用场景
- 回归任务(如房价预测、能源消耗预测)。
- 数据具有非线性关系,线性模型无法有效拟合。
- 需要可解释性强的模型(决策路径可视化)。
9. 结论
-
DecisionTreeRegressor
适用于回归任务,基于树结构自动选择最优特征进行回归,支持 可视化和文本导出,但 容易过拟合,需要调整max_depth
、min_samples_split
等参数。 -
如果 希望提高泛化能力,可以使用
RandomForestRegressor
。