GBDT source code analysis

1. Source code analysis

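The code examined here comes from sklearn, the well-known Python library. The GBDT-related classes in sklearn (sklearn/ensemble/gradient_boosting.py) are GradientBoostingRegressor and GradientBoostingClassifier, which share the parent class BaseGradientBoosting. The core boosting logic is implemented in BaseGradientBoosting. The main parameters are listed below (see the sklearn documentation for full details):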

loss : {'ls', 'lad', 'huber', 'quantile'}, optional (default='ls')
        loss function to be optimized. 'ls' refers to least squares regression. 'lad' (least absolute deviation) is a highly robust loss function solely based on order information of the input variables. 'huber' is a combination of the two. 'quantile' allows quantile regression (use `alpha` to specify the quantile).

learning_rate : float, optional (default=0.1)
        learning rate shrinks the contribution of each tree by `learning_rate`. There is a trade-off between learning_rate and n_estimators.

n_estimators : int (default=100)
        The number of boosting stages to perform. Gradient boosting is fairly robust to over-fitting so a large number usually results in better performance.

max_depth : integer, optional (default=3)
        maximum depth of the individual regression estimators. The maximum depth limits the number of nodes in the tree. Tune this parameter for best performance; the best value depends on the interaction of the input variables.

....

init : BaseEstimator, None, optional (default=None)

....
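
The trade-off between learning_rate and n_estimators is easier to see when the model is written as a sum of stages. The following is a sketch in standard gradient-boosting notation; the symbols F_0, h_m, \nu and M are introduced here for illustration and do not appear in the parameter descriptions above.

```latex
F_m(x) = F_{m-1}(x) + \nu \, h_m(x), \qquad m = 1, \dots, M
```

Here \nu stands for learning_rate, M for n_estimators, h_m for the m-th regression tree, and F_0 for the initial prediction. A smaller \nu shrinks each tree's contribution, so more stages are typically needed to reach the same fit on the training data.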

The main function of interest in BaseGradientBoosting is fit.

Initialization

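sklearn provides several estimators for the first step of the algorithm, initialization. The default is MeanEstimator, which uses the mean of the training targets as the initial prediction: fit() computes the mean, and predict() sets the initial prediction y_pred of every sample in X to that mean. (The other initialization estimators work similarly.)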

class MeanEstimator(BaseEstimator):
    """An estimator predicting the mean of the training targets."""
    def fit(self, X, y, sample_weight=None):
        if sample_weight is None:
            self.mean = np.mean(y)
        else:
            self.mean = np.average(y, weights=sample_weight)

    def predict(self, X):
        check_is_fitted(self, 'mean')

        y = np.empty((X.shape[0], 1), dtype=np.float64)
        y.fill(self.mean)
        return y
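
To make the behaviour concrete, here is a minimal standalone sketch of the same idea in plain Python, with no NumPy or sklearn dependency. The function names fit_mean and predict_mean are hypothetical and are not part of sklearn; the sketch only mirrors what fit() and predict() do above.

```python
def fit_mean(y, sample_weight=None):
    """Return the (optionally weighted) mean of the training targets."""
    if sample_weight is None:
        return sum(y) / len(y)
    # Weighted mean: sum(w_i * y_i) / sum(w_i)
    total_weight = sum(sample_weight)
    return sum(w * t for w, t in zip(sample_weight, y)) / total_weight


def predict_mean(mean, n_samples):
    """Return the same initial prediction for every sample."""
    return [mean] * n_samples


y = [1.0, 2.0, 6.0]
init = fit_mean(y)                   # unweighted mean of the targets
y_pred = predict_mean(init, len(y))  # one identical prediction per sample
weighted_init = fit_mean(y, sample_weight=[1.0, 1.0, 2.0])
```

Every sample starts from the same constant prediction; the later boosting stages are what make the predictions differ between samples.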

Computing the residuals

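sklearn provides several loss functions (LossFunction); the default is the least-squares loss LeastSquaresError. Each loss function computes the "pseudo-residuals" as its negative gradient. The negative gradient of LeastSquaresError, negative_gradient(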
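
As an illustration of pseudo-residuals for the least-squares case: assuming the loss is written as 1/2 * (y - F)^2, its negative gradient with respect to the prediction F is y - F, which is the ordinary residual. The sketch below is plain Python with hypothetical function names; it is not sklearn's implementation, and the "perfect tree" in the update step is a stand-in for a fitted regression tree.

```python
def negative_gradient_ls(y, y_pred):
    """Pseudo-residuals for the loss 1/2 * (y - F)^2: simply y - F."""
    return [t - p for t, p in zip(y, y_pred)]


def boosting_update(y_pred, tree_output, learning_rate):
    """Add the shrunken output of one stage to the current predictions."""
    return [p + learning_rate * h for p, h in zip(y_pred, tree_output)]


y = [1.0, 2.0, 6.0]
y_pred = [3.0, 3.0, 3.0]  # initial prediction: the mean of y

residuals = negative_gradient_ls(y, y_pred)

# Assumption for illustration: the stage's tree reproduces the residuals exactly.
y_pred_next = boosting_update(y_pred, residuals, learning_rate=0.1)
```

After the update each prediction has moved a small step, scaled by learning_rate, towards its target.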
