sklearn代码16 6-梯度提升树原理

import numpy as np

import matplotlib.pyplot as plt
%matplotlib inline

# 回归是分类的极限思想
# 分类的类别多到一定程度,就是回归

from sklearn.ensemble import GradientBoostingClassifier,GradientBoostingRegressor
from sklearn import tree

# X数据:上网时间和购物金额
# y目标:14 16, 24,26
X = np.array([[800,3],[1200,1],[1800,4],[2500,2]])

y = np.array([14,16,24,26])
gbdt= GradientBoostingClassifier(n_estimators=10)

gbdt.fit(X, y)

GradientBoostingClassifier(criterion='friedman_mse', init=None,
                           learning_rate=0.1, loss='deviance', max_depth=3,
                           max_features=None, max_leaf_nodes=None,
                           min_impurity_decrease=0.0, min_impurity_split=None,
                           min_samples_leaf=1, min_samples_split=2,
                           min_weight_fraction_leaf=0.0, n_estimators=10,
                           n_iter_no_change=None, presort='auto',
                           random_state=None, subsample=1.0, tol=0.0001,
                           validation_fraction=0.1, verbose=0,
                           warm_start=False)
gbdt.predict(X)
array([14, 16, 24, 26])
gbdt[0,0].predict(X)
array([ 3., -1., -1., -1.])
gbdt[-1,0].predict(X)
array([ 0.98250675, -0.81422807, -0.81422807, -0.81422807])
# 使用回归
gbdt = GradientBoostingRegressor(n_estimators=10)

gbdt.fit(X,y)

gbdt.predict(X)
array([ 16.09207064,  17.39471376,  22.60528624,  23.90792936])
#mse mean-square-error:均方误差 越小,说明预测出的值越准确
((y - y.mean())**2).mean()
26.0
((y[:2]-y[:2].mean())**2).mean()
1.0

第一棵树,根据平均值,计算量残差[-6,-4,4,6]

plt.rcParams['font.sans-serif']='KaiTi'
plt.figure(figsize=(9,6))
_= tree.plot_tree(gbdt[0,0],filled=True,feature_names=['消费','上网'])

请添加图片描述

# learning rate =0.1
gbdt1 = np.array([6,-4,4,6])

# 梯度提升 学习率0.1

gbdt1 - gbdt1*0.1
array([ 5.4, -3.6,  3.6,  5.4])

根据梯度提升,减少残差(残差越小,结果越好)

plt.rcParams['font.sans-serif']='KaiTi'
plt.figure(figsize=(9,6))
_= tree.plot_tree(gbdt[1,0],filled=True,feature_names=['消费','上网'])

请添加图片描述

plt.rcParams['font.sans-serif']='KaiTi'
plt.figure(figsize=(9,6))
_= tree.plot_tree(gbdt[2,0],filled=True,feature_names=['消费','上网'])

请添加图片描述

  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值