【scikit-learn】sklearn.tree.DecisionTreeRegressor 类:决策树回归

sklearn.tree.DecisionTreeRegressor(决策树回归)

DecisionTreeRegressorsklearn.tree 提供的 回归任务决策树模型,用于 预测连续数值,适用于 非线性关系的回归问题


1. DecisionTreeRegressor 作用

  • 适用于回归任务(如房价预测、股票走势预测)。
  • 自动选择最优特征进行分裂,适用于 非线性数据
  • 可解释性强,可视化决策路径

2. DecisionTreeRegressor 代码示例

(1) 训练决策树回归模型

from sklearn.tree import DecisionTreeRegressor
from sklearn.datasets import make_regression
import matplotlib.pyplot as plt
import numpy as np

# 生成回归数据
X, y = make_regression(n_samples=100, n_features=1, noise=10, random_state=42)

# 训练决策树回归模型
model = DecisionTreeRegressor(max_depth=3)
model.fit(X, y)

# 预测
X_sorted = np.sort(X, axis=0)  # 排序 X 以便可视化
y_pred = model.predict(X_sorted)

# 可视化
plt.scatter(X, y, label="真实数据")
plt.plot(X_sorted, y_pred, color="red", label="决策树预测")
plt.legend()
plt.show()

解释

  • max_depth=3 限制决策树深度,防止过拟合。
  • 适用于非线性回归任务,但 容易过拟合

3. DecisionTreeRegressor 主要参数

DecisionTreeRegressor(criterion="squared_error", max_depth=None, min_samples_split=2, min_samples_leaf=1, random_state=None)
参数说明
criterion“squared_error”(默认) or “friedman_mse” or “absolute_error”(损失函数)
max_depth最大树深度(默认 None,自动生长)
min_samples_split分裂内部节点的最小样本数(默认 2
min_samples_leaf叶子节点的最小样本数(默认 1
random_state设置随机种子,保证结果可复现

4. 获取特征重要性

feature_importances = model.feature_importances_
print("特征重要性:", feature_importances)

解释

  • feature_importances_ 返回每个特征的重要性(值越大,该特征越重要)。

5. 计算模型性能

from sklearn.metrics import mean_squared_error, r2_score

y_test_pred = model.predict(X)
mse = mean_squared_error(y, y_test_pred)
r2 = r2_score(y, y_test_pred)

print("均方误差 MSE:", mse)
print("决定系数 R²:", r2)

解释

  • MSE(均方误差):值越小,拟合效果越好。
  • R²(决定系数)1 表示完美拟合,0 表示无解释能力。

6. 决策树可视化

(1) plot_tree()

from sklearn.tree import plot_tree
import matplotlib.pyplot as plt

plt.figure(figsize=(12, 8))
plot_tree(model, filled=True)
plt.show()

解释

  • 可视化决策树结构,查看分裂情况

(2) export_text()

from sklearn.tree import export_text

tree_rules = export_text(model)
print(tree_rules)

解释

  • 以文本格式导出决策树规则

7. DecisionTreeRegressor vs. LinearRegression

模型适用情况目标变量适合数据
LinearRegression线性回归连续值线性数据
DecisionTreeRegressor决策树回归连续值非线性数据

示例

from sklearn.linear_model import LinearRegression

reg = LinearRegression().fit(X, y)
tree_reg = DecisionTreeRegressor(max_depth=3).fit(X, y)

print("线性回归预测:", reg.predict(X)[:5])
print("决策树回归预测:", tree_reg.predict(X)[:5])

解释

  • 线性回归适用于线性数据,决策树适用于非线性数据

8. 适用场景

  • 回归任务(如房价预测、能源消耗预测)。
  • 数据具有非线性关系,线性模型无法有效拟合。
  • 需要可解释性强的模型(决策路径可视化)。

9. 结论

  • DecisionTreeRegressor 适用于回归任务,基于树结构自动选择最优特征进行回归,支持 可视化和文本导出,但 容易过拟合,需要调整 max_depthmin_samples_split 等参数

  • 如果 希望提高泛化能力,可以使用 RandomForestRegressor

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

彬彬侠

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值