For the classification case, see the companion post GBDT: Classification.
The Gradient Boost algorithm
LS_TreeBoost
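As a reference point, the least-squares procedure can be written out as below. This is a sketch in Friedman's notation, assuming a J-leaf regression tree as the base learner and a shrinkage factor \nu (the learning rate); the symbols R_{jm}, \gamma_{jm} and \nu are assumptions introduced here.

```latex
\begin{aligned}
& F_0(x) = \bar{y} \\
& \text{for } m = 1, \dots, M: \\
& \qquad \tilde{y}_i = y_i - F_{m-1}(x_i), \quad i = 1, \dots, N \\
& \qquad \{R_{jm}\}_{j=1}^{J} = J\text{-terminal node tree}\big(\{\tilde{y}_i, x_i\}_{i=1}^{N}\big) \\
& \qquad \gamma_{jm} = \operatorname{mean}_{x_i \in R_{jm}} \tilde{y}_i \\
& \qquad F_m(x) = F_{m-1}(x) + \nu \sum_{j=1}^{J} \gamma_{jm}\, \mathbf{1}(x \in R_{jm})
\end{aligned}
```

Because a least-squares tree already stores the mean residual in each leaf, no separate leaf update is needed; setting \nu = 1 gives the unshrunk form.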
LAD_TreeBoost
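For comparison, the least-absolute-deviation procedure can be sketched in the same way. This follows Friedman's notation and assumes a J-leaf regression tree and a shrinkage factor \nu; the symbols R_{jm}, \gamma_{jm} and \nu are assumptions introduced here.

```latex
\begin{aligned}
& F_0(x) = \operatorname{median}\{y_i\}_{i=1}^{N} \\
& \text{for } m = 1, \dots, M: \\
& \qquad \tilde{y}_i = \operatorname{sign}\big(y_i - F_{m-1}(x_i)\big), \quad i = 1, \dots, N \\
& \qquad \{R_{jm}\}_{j=1}^{J} = J\text{-terminal node tree}\big(\{\tilde{y}_i, x_i\}_{i=1}^{N}\big) \\
& \qquad \gamma_{jm} = \operatorname{median}_{x_i \in R_{jm}} \{\, y_i - F_{m-1}(x_i) \,\} \\
& \qquad F_m(x) = F_{m-1}(x) + \nu \sum_{j=1}^{J} \gamma_{jm}\, \mathbf{1}(x \in R_{jm})
\end{aligned}
```

Here the tree is fitted to the signs of the residuals, so each leaf value has to be recomputed as the median of the actual residuals that fall into that leaf.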
Reading the sklearn source code
sklearn.ensemble.GradientBoostingRegressor(loss='ls', learning_rate=0.1, n_estimators=100, subsample=1.0, criterion='friedman_mse', min_samples_split=2, min_samples_leaf=1, min_weight_fraction_leaf=0.0, max_depth=3, min_impurity_decrease=0.0, min_impurity_split=None, init=None, random_state=None, max_features=None, alpha=0.9, verbose=0, max_leaf_nodes=None, warm_start=False, presort='auto', validation_fraction=0.1, n_iter_no_change=None, tol=0.0001)
loss : {'ls', 'lad', 'huber', 'quantile'} | the loss function to be optimized;
learning_rate : (default=0.1) | shrinks the contribution of each tree by learning_rate;
n_estimators : (default=100) | the number of boosting stages to perform;
subsample : (default=1.0) | the fraction of samples used to fit each individual base learner;
criterion : {'friedman_mse', 'mse', 'mae'} | the function used to measure the quality of a split.
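The parameters above can be exercised with a short script. This is a minimal sketch on made-up data (the X and y values are assumptions, not from the original post). The loss argument is left at its default, which is least squares, because the name used for that loss differs between scikit-learn releases.

```python
from sklearn.ensemble import GradientBoostingRegressor

# Hypothetical toy data: a linear trend with an alternating +/-1 offset.
X = [[float(i)] for i in range(1, 21)]
y = [0.5 * i + (1.0 if i % 2 else -1.0) for i in range(1, 21)]


def mse(target, pred):
    """Plain mean squared error."""
    return sum((t - p) ** 2 for t, p in zip(target, pred)) / len(target)


model = GradientBoostingRegressor(
    learning_rate=0.1,   # shrinkage applied to every tree
    n_estimators=100,    # number of boosting stages
    subsample=1.0,       # use all samples for every tree
    max_depth=3,
    random_state=0,
)
model.fit(X, y)

# Training error of the constant mean predictor, for comparison.
baseline_mse = mse(y, [sum(y) / len(y)] * len(y))
train_mse = mse(y, model.predict(X))

# Training error after each boosting stage.
stage_mse = [mse(y, stage_pred) for stage_pred in model.staged_predict(X)]

print(baseline_mse, train_mse, len(stage_mse))
```

Lowering learning_rate usually calls for a larger n_estimators, since each tree then contributes a smaller correction.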
- LeastSquaresError
class LeastSquaresError(RegressionLossFunction):
    """
    Loss function for least squares (LS) estimation.
    Terminal regions need not to be updated for least squares.
    """
    def init_estimator(self):
        ''' Initialize F0 '''
        return MeanEstimator()

    def __call__(self, y, pred, sample_weight=None):
        ''' Compute the current loss '''
        if sample_weight is None:
            return np.mean((y - pred.ravel()) ** 2.0)
        else:
            return (1.0 / sample_weight.sum() *
                    np.sum(sample_weight * ((y - pred.ravel()) ** 2.0)))

    def negative_gradient(self, y, pred, **kargs):
        ''' Compute the negative gradient (the residual) '''
        return y - pred.ravel()

    def update_terminal_regions(self, tree, X, y, residual, y_pred,
                                sample_weight, sample_mask,
                                learning_rate=0.1, k=0):
        ''' Update Fm(x) '''
        y_pred[:, k] += learning_rate * tree.predict(X).ravel()
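The loss and its negative gradient can be checked on a few numbers. The sketch below re-implements both in plain Python (the function names ls_loss, negative_gradient and half_squared_error are assumptions, not sklearn API) and verifies by finite differences that y - pred is the negative gradient of 0.5 * (y - pred) ** 2. The loss reported by `__call__` omits the factor 0.5, so the two differ only by a constant factor of 2.

```python
def ls_loss(y, pred, sample_weight=None):
    """(Weighted) mean squared error, mirroring LeastSquaresError.__call__."""
    if sample_weight is None:
        sample_weight = [1.0] * len(y)
    total = sum(w * (yi - pi) ** 2 for w, yi, pi in zip(sample_weight, y, pred))
    return total / sum(sample_weight)


def negative_gradient(y, pred):
    """Residuals y - pred, mirroring LeastSquaresError.negative_gradient."""
    return [yi - pi for yi, pi in zip(y, pred)]


def half_squared_error(y, pred):
    """Sum of 0.5 * (y - pred)^2, the loss whose negative gradient is y - pred."""
    return sum(0.5 * (yi - pi) ** 2 for yi, pi in zip(y, pred))


y = [1.0, 2.0, 4.0]
pred = [1.5, 1.5, 3.0]

loss = ls_loss(y, pred)                            # (0.25 + 0.25 + 1.0) / 3
weighted_loss = ls_loss(y, pred, [1.0, 1.0, 2.0])  # (0.25 + 0.25 + 2.0) / 4
residual = negative_gradient(y, pred)              # [-0.5, 0.5, 1.0]

# Forward finite differences of the half squared error w.r.t. each prediction.
eps = 1e-6
numeric = []
for i in range(len(pred)):
    bumped = list(pred)
    bumped[i] += eps
    slope = (half_squared_error(y, bumped) - half_squared_error(y, pred)) / eps
    numeric.append(-slope)

print(loss, weighted_loss, residual, numeric)
```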
The MeanEstimator class used to initialize F0 is shown below:
class MeanEstimator:
    def fit(self, X, y, sample_weight=None):
        ''' For MSE, use the (weighted) mean of y as the initial value F0 '''
        if sample_weight is None:
            self.mean = np.mean(y)
        else:
            self.mean = np.average(y, weights=sample_weight)

    def predict(self, X):
        """ Predict labels """
        check_is_fitted(self, 'mean')
        # Every sample gets the same constant prediction, shape (n_samples, 1).
        y = np.empty((X.shape[0], 1), dtype=np.float64)
        y.fill(self.mean)
        return y
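Why the mean is a sensible F0 for squared error can be seen on a tiny example. This is a plain-Python sketch (the helper names initial_prediction and squared_error and the data are assumptions for illustration): it computes the unweighted and weighted means, then confirms that nudging the constant away from the mean increases the squared error.

```python
def initial_prediction(y, sample_weight=None):
    """Constant F0: the mean of y, or the weighted mean if weights are given."""
    if sample_weight is None:
        return sum(y) / len(y)
    return sum(w * yi for w, yi in zip(sample_weight, y)) / sum(sample_weight)


def squared_error(y, constant):
    """Sum of squared errors when every sample is predicted as `constant`."""
    return sum((yi - constant) ** 2 for yi in y)


y = [1.0, 2.0, 6.0]

f0 = initial_prediction(y)                            # (1 + 2 + 6) / 3
f0_weighted = initial_prediction(y, [1.0, 1.0, 2.0])  # (1 + 2 + 12) / 4

# The mean should beat any nearby constant under squared error.
error_at_mean = squared_error(y, f0)
error_below = squared_error(y, f0 - 0.1)
error_above = squared_error(y, f0 + 0.1)

print(f0, f0_weighted, error_at_mean, error_below, error_above)
```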
Worked example
x | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 |
y | 5.56 | 5.70 | 5.91 | 6.40 | 6.80 | 7.05 | 8.90 | 8.70 | 9.00 | 9.05 |
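The first boosting stage on this table can be reproduced by hand. The sketch below is plain Python under two assumptions that the table itself does not state: each tree is a depth-1 stump, and learning_rate is the default 0.1. It computes F0 as the mean of y, fits the best single split to the residuals, and applies the shrunken update.

```python
x = list(range(1, 11))
y = [5.56, 5.70, 5.91, 6.40, 6.80, 7.05, 8.90, 8.70, 9.00, 9.05]


def mean(values):
    return sum(values) / len(values)


def best_stump(x, r):
    """Best single split of residuals r by squared error (x must be sorted)."""
    best = None
    for i in range(1, len(x)):
        threshold = (x[i - 1] + x[i]) / 2
        left, right = r[:i], r[i:]
        c1, c2 = mean(left), mean(right)
        sse = (sum((v - c1) ** 2 for v in left)
               + sum((v - c2) ** 2 for v in right))
        if best is None or sse < best[0]:
            best = (sse, threshold, c1, c2)
    return best


# Stage 0: constant prediction F0 = mean(y) = 7.307.
f0 = mean(y)
residual = [yi - f0 for yi in y]

# Stage 1: fit a stump to the residuals (split at 6.5, leaves about -1.0703 and 1.6055).
sse, threshold, c1, c2 = best_stump(x, residual)

# Shrunken update F1(x) = F0 + learning_rate * h1(x).
learning_rate = 0.1  # assumed: the default from the signature
f1 = [f0 + learning_rate * (c1 if xi < threshold else c2) for xi in x]

mse0 = mean([(yi - f0) ** 2 for yi in y])
mse1 = mean([(yi - fi) ** 2 for yi, fi in zip(y, f1)])
print(f0, threshold, c1, c2, mse0, mse1)
```

With shrinkage, a single stage moves the prediction only a tenth of the way toward the leaf means, so the training error falls gradually over many stages.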
1