XGBoost梯度提升
结构化数据最精确的建模技术。
在本节课中,我们将学习如何使用梯度增强来构建和优化模型。这种方法在Kaggle的许多竞争中占据主导地位,并在各种数据集上获得了最先进的结果。
本课程所需数据集夸克网盘下载链接:https://pan.quark.cn/s/9b4e9a1246b2 提取码:uDzP
文章目录
1、简介
我们已经使用随机森林方法进行了预测,该方法仅通过对许多决策树的预测进行平均就可以获得比单个决策树更好的性能。
我们把随机森林方法称为“集成方法”。根据定义,集成方法结合了几种模型的预测(例如,在随机森林的情况下,若干棵树)。
接下来,我们将学习另一种称为梯度增强的集成方法。
2、梯度提升
梯度提升是一种通过循环迭代将模型添加到集成中的方法。
它首先使用单个模型初始化集合,这个模型的预测可能非常天真。(即使它的预测非常不准确,随后增加的集合将解决这些错误。)
然后,我们开始循环:
- 它首先用一个模型初始化集成,该模型的预测可能相当天真。(即使它的预测非常不准确,后续的补充将解决这些错误。)
- 然后,我们开始这个循环:首先,我们使用当前集成来为数据集中的每个观察结果生成预测。
- 为了做出预测,我们将所有模型的预测相加。这些预测用于计算损失函数(例如,均方误差)。
- 然后,我们使用损失函数来拟合一个将被添加到集成中的新模型。具体地说,我们确定模型参数,以便将这个新模型添加到集成中将减少损失。(注:“梯度增强”中的“梯度”指的是我们将在损失函数上使用梯度下降来确定新模型中的参数。)
- 最后,我们将新的模型加入到集成中,并且… 重复!
3、举例
在本例中,我们将使用XGBoost库。XGBoost是extreme gradient boost的缩写,它是梯度增强的一种实现,还有几个额外的特性侧重于性能和速度。
(Scikit-learn有另一个梯度增强版本,但XGBoost有一些技术优势。)
在下一个代码单元中,我们将导入用于XGBoost (XGBoost . xgbregressor)的scikit-learn API。
这使我们能够像在scikit-learn中一样构建和适应一个模型。正如您将在输出中看到的,XGBRegressor类有许多可调参数——您很快就会了解这些参数!
In [1]:
import pandas as pd
from sklearn.model_selection import train_test_split
# 加载数据
data = pd.read_csv('../input/melbourne-housing-snapshot/melb_data.csv')
# 选择预测子集
cols_to_use = ['Rooms', 'Distance', 'Landsize', 'BuildingArea', 'YearBuilt']
X = data[cols_to_use]
# 选择目标
y = data.Price
# 将数据分成训练集和验证集
X_train, X_valid, y_train, y_valid = train_test_split(X, y)
在本例中,您将使用 XGBoost 库。XgBoost 代表了极限梯度提升,这是一个梯度提升的实现,其中包含了一些侧重于性能和速度的额外特性。(Scikit-learn 有另一个版本的梯度提升,但 XgBoost 有一些技术优势。)
在下一个代码单元中,我们为 XGBoost (XGBoost.XGBRegressor
)导入 scikit-learn API。这使我们能够建立和适应一个模型,就像我们在 scikit-learn 中所做的那样。正如您将在输出中看到的,XGBRegressor
类有许多可调参数——您很快就会了解这些参数!
In [2]:
from xgboost import XGBRegressor
my_model = XGBRegressor()
my_model.fit(X_train, y_train)
Out[2]:
XGBRegressor(base_score=0.5, booster='gbtree', callbacks=None,
colsample_bylevel=1, colsample_bynode=1, colsample_bytree=1,
early_stopping_rounds=None, enable_categorical=False,
eval_metric=None, gamma=0, gpu_id=-1, grow_policy='depthwise',
importance_type=None, interaction_constraints='',
learning_rate=0.300000012, max_bin=256, max_cat_to_onehot=4,
max_delta_step=0, max_depth=6