12天summer----高级算法梳理-GBDT算法梳理

梯度提升回归树是另一种集成方法,通过合并多个决策树来构建一个更为强大的模型。虽然名字中含有“回归”,但这个模型既可以用于回归也可以用于分类。与随机森林方法不同,梯度提升采用连续的方式构造树,每棵树都试图纠正前一棵树的错误。默认情况下,梯度提升回归树中没有随机化,而是用到了强预剪枝。梯度提升树通常使用深度很小(1到 5 之间)的树,这样模型占用的内存更少,预测速度也更快。

梯度提升背后的主要思想是合并许多简单的模型(在这个语境中叫作弱学习器),比如深度较小的树。每棵树只能对部分数据做出好的预测,因此,添加的树越来越多,可以不断迭代提高性能。梯度提升树经常是机器学习竞赛的优胜者,并且广泛应用于业界。与随机森林相比,它通常对参数设置更为敏感,但如果参数设置正确的话,模型精度更高。

除了预剪枝与集成中树的数量之外,梯度提升的另一个重要参数是 learning_rate(学习率),用于控制每棵树纠正前一棵树的错误的强度。较高的学习率意味着每棵树都可以做出较强的修正,这样模型更为复杂。通过增大 n_estimators 来向集成中添加更多树,也可以增加模型复杂度,因为模型有更多机会纠正训练集上的错误。

from sklearn.ensemble import GradientBoostingClassifier 
from sklearn.datasets import load_breast_cancer
cancer=load_breast_cancer()
X_train, X_test, y_train, y_test = train_test_split(cancer.data, cancer.target, random_state=0)  
gbrt = GradientBoostingClassifier(random_state=0) 
gbrt.fit(X_train, y_train)  
print("Accuracy on training set: {:.3f}".format(gbrt.score(X_train, y_train))) 
print("Accuracy on test set: {:.3f}".format(gbrt.score(X_test, y_test)))

 

由于训练集精度达到 100%,所以很可能存在过拟合。为了降低过拟合,我们可以限制最大深度来加强预剪枝,也可以降低学习率:  

gbrt = GradientBoostingClassifier(random_state=0, max_depth=1) 
gbrt.fit(X_train, y_train)  
print("Accuracy on training set: {:.3f}".format(gbrt.score(X_train, y_train))) 
print("Accuracy on test set: {:.3f}".format(gbrt.score(X_test, y_test)))

 

gbrt = GradientBoostingClassifier(random_state=0, learning_rate=0.01) 
gbrt.fit(X_train, y_train) 
print("Accuracy on training set: {:.3f}".format(gbrt.score(X_train, y_train))) 
print("Accuracy on test set: {:.3f}".format(gbrt.score(X_test, y_test)))

降低模型复杂度的两种方法都降低了训练集精度,这和预期相同。在这个例子中,减小树的最大深度显著提升了模型性能,而降低学习率仅稍稍提高了泛化性能。

排序特征重要性,以便更好地理解模型

importances_all={}
importances = gbrt.feature_importances_
feat_labels = cancer.columns[0:-1]
indices = np.argsort(importances)[::-1]
for f in range(X_train.shape[1]):
    print("%2d) %-*s %f" % (f + 1, 30, feat_labels[indices[f]], importances[indices[f]]))
    importances_all.update({feat_labels[indices[f]]:importances[indices[f]]})
#建立特征重要度的表
importances=pd.DataFrame(importances_all.values(),index=importances_all.keys())
#设定大于0.05阈值的特征显示
threshold = 0.05
x_selected =importances[importances[0]>0.05]

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值