Reading the sklearn GBDT source
2017/01/09 22:05 V.0.1 This first version does not dwell on source-code details; it focuses on the overall structure of the code. Later versions will fill in the specifics.
2017/01/11 01:25 V.0.2 Same focus as the first version: overall code structure rather than details; the specifics will come in later versions.
I've been doing a lot of data mining lately and have used GBDT a bit, so I wanted to see how the source code implements it.
When we train a GBDT model:
gbdt = sklearn.ensemble.GradientBoostingClassifier(param)
So we locate the corresponding code in the source tree:
class GradientBoostingClassifier(BaseGradientBoosting, ClassifierMixin):

    _SUPPORTED_LOSS = ('deviance', 'exponential')

    def __init__(self, loss='deviance', learning_rate=0.1, n_estimators=100,
                 subsample=1.0, min_samples_split=2,
                 min_samples_leaf=1, min_weight_fraction_leaf=0.,
                 max_depth=3, init=None, random_state=None,
                 max_features=None, verbose=0,
                 max_leaf_nodes=None, warm_start=False,
                 presort='auto'):

        super(GradientBoostingClassifier, self).__init__(
            loss=loss, learning_rate=learning_rate, n_estimators=n_estimators,
            min_samples_split=min_samples_split,
            min_samples_leaf=min_samples_leaf,
            min_weight_fraction_leaf=min_weight_fraction_leaf,
            max_depth=max_depth, init=init, subsample=subsample,
            max_features=max_features,
            random_state=random_state, verbose=verbose,
            max_leaf_nodes=max_leaf_nodes, warm_start=warm_start,
            presort=presort)
We can see that GradientBoostingClassifier inherits from a parent class, BaseGradientBoosting, and GradientBoostingRegressor inherits from this same parent. In that parent class we find the following code:
class BaseGradientBoosting(six.with_metaclass(ABCMeta, BaseEnsemble,
                                              _LearntSelectorMixin)):
    ...

    def fit(self, X, y, sample_weight=None, monitor=None):
        """Fit the gradient boosting model."""
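The shared-parent claim is easy to verify from outside the library by inspecting the method resolution order of both public classes (assuming the base class keeps the name BaseGradientBoosting, as in the source read here):

```python
# Check that the classifier and the regressor share the same
# gradient-boosting base class by comparing their MROs.
import inspect
from sklearn.ensemble import GradientBoostingClassifier, GradientBoostingRegressor

clf_bases = {c.__name__ for c in inspect.getmro(GradientBoostingClassifier)}
reg_bases = {c.__name__ for c in inspect.getmro(GradientBoostingRegressor)}

# Classes appearing in both hierarchies:
shared = clf_bases & reg_bases
```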
Clearly, this fit function is what gets called when we train a model:
clf = clf.fit(train[predictors], train[target])
Next, let's dig into this fit function to see how GBDT actually trains the model. Steps such as the warm_start handling and the check_X_y and check_random_state calls are plainly basic input validation, so we skip straight past them to the most important code:
# fit the boosting stages
n_stages = self._fit_stages(X, y, y_pred, sample_weight, random_state,
                            begin_at_stage, monitor, X_idx_sorted)

# change shape of arrays after fit (early-stopping or additional ests)
if n_stages != self.estimators_.shape[0]:
    self.estimators_ = self.estimators_[:n_stages]
    self.train_score_ = self.train_score_[:n_stages]
    if hasattr(self, 'oob_improvement_'):
        self.oob_improvement_ = self.oob_improvement_[:n_stages]
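Before going further into _fit_stages, it helps to have the core idea in mind. The following is a simplified sketch, not sklearn's actual implementation: for squared-error loss, each boosting stage fits a regression tree to the residuals of the current model (the negative gradient) and adds its prediction, scaled by the learning rate. The function name fit_stages and its signature are invented here for illustration.

```python
import numpy as np
from sklearn.tree import DecisionTreeRegressor

def fit_stages(X, y, n_estimators=20, learning_rate=0.1, max_depth=3):
    """Toy gradient boosting for squared-error loss (illustrative only)."""
    # Stage 0: a constant prediction (the mean minimizes squared error).
    y_pred = np.full(y.shape, y.mean())
    trees = []
    for _ in range(n_estimators):
        # For squared loss the negative gradient is simply the residual.
        residual = y - y_pred
        tree = DecisionTreeRegressor(max_depth=max_depth, random_state=0)
        tree.fit(X, residual)
        # Shrink each stage's contribution by the learning rate.
        y_pred = y_pred + learning_rate * tree.predict(X)
        trees.append(tree)
    return trees, y_pred
```

The real _fit_stages additionally handles subsampling, arbitrary loss functions via their gradient, out-of-bag score tracking, and the early-stopping monitor seen in the snippet above.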