对比XGBoost.cv和sklearn中的交叉验证

写在前面:已经很久很久很久没有发博客了,有点愧疚还有点难过,不写博客的实践都干嘛了,哎!!!

XGBoost有两种接口:

  1. 原生接口,比如xgboost.trainxgboost.cv
  2. sklearn接口,比如xgboost.XGBClassifierxgboost.XGBRegressor

两种接口有些许不同,比如原生接口的学习率参数是eta,sklearn接口的是learning_rate,原生接口要在traincv函数中传入num_round作为基学习器个数,而sklearn接口在定义模型时使用参数n_estimators。sklearn接口的形式与sklearn中的模型保持统一,方便sklearn用户学习。

如果要对XGBoost模型进行交叉验证,可以使用原生接口的交叉验证函数xgboost.cv;对于sklearn接口,可以使用sklearn.model_selection中的cross_val_scorecross_validatevalidation_curve三个函数。

sklearn.model_selection中的三个函数区别:

  1. cross_val_score最简单,返回模型给定参数的验证得分,不能返回训练得分
  2. cross_validate复杂一些,返回模型给定参数的训练得分、验证得分、训练时间和验证时间等,甚至还可以指定多个评价指标
  3. validation_curve返回模型指定一个参数的一系列候选值的训练得分和验证得分,可以通过判断拟合情况来调整该参数,也可以用来画validation_curve

下面分别以分类任务和回归任务展示一下四个函数的用法和输出情况。经过对比,在参数相同的条件下,四个函数的输出结果一致。发现了一个问题,validation_curvexgboost.cv的输出结果大部分相同,但是前者的耗时却比后者多了好几倍。(暂时还找到原因,网上也没找到相同的问题,打算到stackoverflow上问一下,如果有答案的话再回来补充)

20200402补充:初步怀疑是热启动的问题,在使用xgboost.cv进行交叉验证时,可以通过热启动的方式训练模型,此时只需要训练 N N N棵树;而把XGBRegressor传入validation_curve进行交叉验证,此时XGBRegressor不能设置热启动(而sklearn的GBDT和随机森林都可以设置热启动),那就需要训练 1 + 2 + . . . + N = N ∗ ( N − 1 ) 2 1+2+...+N = \frac{N*(N-1)}{2} 1+2+...+N=2N(N1)棵树,自然速度就慢了。

P.S. 下面代码是用Jupyter Notebook写的,懒得合并了。

import numpy as np
import xgboost as xgb

from sklearn.datasets import make_regression
from sklearn.datasets import make_classification
from sklearn.model_selection import cross_val_score
from sklearn.model_selection import cross_validate
from sklearn.model_selection import KFold
from sklearn.model_selection import validation_curve
# 回归问题
X, y = make_regression(n_samples=10000, n_features=10)
# sklearn接口
n_estimators = 50
params = {'n_estimators':n_estimators, 'booster':'gbtree', 'max_depth':5, 'learning_rate':0.05,
          'objective':'reg:squarederror', 'subsample':1, 'colsample_bytree':1}
clf = xgb.XGBRegressor(**params)
cv = KFold(n_splits=5, shuffle=True, random_state=100)
print('test_score:', cross_val_score(clf, X, y, cv=cv, scoring='neg_mean_absolute_error'))
test_score: array([-57.48422753, -59.69255262, -58.91771172, -58.44347715,
       -59.8880623 ])
cross_validate(clf, X, y, cv=cv, scoring='neg_mean_absolute_error', 
               return_train_score=True)
{'fit_time': array([0.37278223, 0.36898613, 0.36637878, 0.36504936, 0.37162185]),
 'score_time': array([0.00398517, 0.00403547, 0.00398993, 0.00398636, 0.00404048]),
 'test_score': array([-57.48422753, -59.69255262, -58.91771172, -58.44347715,
        -59.8880623 ]),
 'train_score': array([-50.70099151, -50.43187094, -50.75229625, -50.66844022,
        -50.82982251])}
%%time # 计算一个cell的执行实践

estimator_range = range(1, n_estimators+1)
train_score, test_score = validation_curve(
    clf, X, y, param_name='n_estimators', param_range=estimator_range,
    cv=cv, scoring='neg_mean_absolute_error'
)

print('train_score:',train_score[-1])
print('test_score:', test_score[-1])
train_score: [-50.70099151 -50.43187094 -50.75229625 -50.66844022 -50.82982251]
test_score: [-57.48422753 -59.69255262 -58.91771172 -58.44347715 -59.8880623 ]
Wall time: 57 s
print('train_mae_mean:\n', np.abs(train_score).mean(axis=1))
print('test_mae_mean:\n', np.abs(test_score).mean(axis=1))
train_mae_mean: 
 array([127.5682212 , 124.17645861, 120.96190697, 117.93807824,
        115.06161926, 112.28746068, 109.60911311, 107.07263957,
        104.63554663, 102.2788341 ,  99.97895509,  97.82509892,
         95.73958223,  93.71896245,  91.79974093,  89.94817809,
         88.13495265,  86.37829884,  84.67090725,  83.05548799,
         81.46903821,  79.94032864,  78.44072613,  77.00488358,
         75.62191062,  74.24138916,  72.92121361,  71.66007955,
         70.41908351,  69.23718699,  68.06376522,  66.92736292,
         65.81928473,  64.73408044,  63.67274508,  62.65390845,
         61.66069004,  60.69802867,  59.72997524,  58.7870726 ,
         57.88241178,  57.01307807,  56.14014094,  55.30247271,
         54.48416963,  53.69873843,  52.91791742,  52.16280788,
         51.42670887,  50.67668428]),
test_mae_mean:
 array([127.83738044, 124.65000719, 121.72020148, 118.90983369,
        116.24294452, 113.69376675, 111.29388701, 108.94996321,
        106.7553701 , 104.62401193, 102.5608943 , 100.68648486,
         98.76550219,  96.97546939,  95.20893969,  93.55259092,
         91.9299438 ,  90.3413075 ,  88.76142948,  87.34316226,
         85.96043718,  84.62054143,  83.30115705,  82.07698107,
         80.89857637,  79.67939585,  78.52190061,  77.37787457,
         76.28248431,  75.24121599,  74.2093299 ,  73.21873113,
         72.19303325,  71.23265487,  70.33854865,  69.42902278,
         68.57191177,  67.73459769,  66.88130101,  66.05978781,
         65.26603807,  64.46357751,  63.70019472,  62.95398889,
         62.25243534,  61.56164243,  60.88819753,  60.20476192,
         59.55280602,  58.88520627])
%%time

params_xgb = params.copy() # 修改参数
num_round = params_xgb['n_estimators']
params_xgb['eta'] = params['learning_rate']
del params_xgb['n_estimators']
del params_xgb['learning_rate']

# xgboost原生接口 进行交叉验证
res = xgb.cv(params_xgb, xgb.DMatrix(X, y), num_round, folds=cv, metrics='mae')
print(res)
    train-mae-mean  train-mae-std  test-mae-mean  test-mae-std
0       127.568312       0.315528     127.837350      1.243183
1       124.176437       0.300477     124.649957      1.236916
2       120.962018       0.301030     121.720238      1.206761
3       117.938005       0.278763     118.909902      1.231662
4       115.061696       0.269224     116.242946      1.190097
5       112.287560       0.240412     113.693771      1.159047
6       109.609152       0.262167     111.293890      1.099815
7       107.072640       0.242916     108.949971      1.067070
8       104.635579       0.209314     106.755350      1.080068
9       102.278841       0.195815     104.624013      1.054731
10       99.978919       0.201804     102.560906      1.055403
11       97.825169       0.213528     100.686517      1.033271
12       95.739612       0.202356      98.765524      1.029646
13       93.719107       0.187538      96.975470      1.005893
14       91.799744       0.175199      95.208905      1.046983
15       89.948177       0.160738      93.552597      1.067333
16       88.134976       0.144838      91.929965      1.052541
17       86.378351       0.163211      90.341278      1.037858
18       84.670908       0.187184      88.761414      0.995875
19       83.055446       0.171080      87.343141      0.981363
20       81.469022       0.164968      85.960420      0.993623
21       79.940317       0.167554      84.620523      0.963820
22       78.440726       0.154343      83.301137      1.004986
23       77.004854       0.141827      82.076961      0.986129
24       75.621930       0.150028      80.898605      0.964261
25       74.241496       0.154140      79.679413      0.949695
26       72.921170       0.140105      78.521875      0.946750
27       71.660085       0.130937      77.377856      0.924869
28       70.419052       0.109023      76.282506      0.928389
29       69.237167       0.107013      75.241214      0.900845
30       68.063844       0.097079      74.209323      0.900476
31       66.927363       0.091163      73.218730      0.942131
32       65.819266       0.091109      72.193025      0.930880
33       64.734090       0.092792      71.232658      0.908819
34       63.672701       0.086543      70.338522      0.932795
35       62.653945       0.088487      69.429022      0.927500
36       61.660666       0.082703      68.571904      0.915664
37       60.697992       0.119144      67.734601      0.882644
38       59.729960       0.126423      66.881299      0.886910
39       58.787107       0.117820      66.059784      0.897685
40       57.882377       0.125402      65.266035      0.877481
41       57.013075       0.109192      64.463574      0.901940
42       56.140131       0.140454      63.700203      0.888990
43       55.302481       0.148805      62.953973      0.834368
44       54.484136       0.145519      62.252445      0.829440
45       53.698661       0.132748      61.561636      0.854725
46       52.917877       0.124366      60.888204      0.875071
47       52.162859       0.133974      60.204764      0.878531
48       51.426765       0.140143      59.552805      0.892451
49       50.676675       0.133987      58.885213      0.873657
Wall time: 2.25 s

validation_curve用了57s,而xgboost.cv只用了2.25s,差距巨大!

# 分类数据集
X, y = make_classification(n_samples=10000, n_features=10, n_classes=2)
n_estimators = 50
params = {'n_estimators':n_estimators, 'booster':'gbtree', 'max_depth':5, 'learning_rate':0.05,
          'objective':'binary:logistic', 'subsample':1, 'colsample_bytree':1}
clf = xgb.XGBClassifier(**params)
cv = KFold(n_splits=5, shuffle=True, random_state=100)
print('test_score:', cross_val_score(clf, X, y, cv=cv, scoring='accuracy'))
test_score: array([0.913 , 0.9235, 0.8955, 0.9075, 0.918 ])
cross_validate(clf, X, y, cv=cv, scoring='accuracy', 
               return_train_score=True)
{'fit_time': array([0.43403697, 0.43297029, 0.41813326, 0.42408895, 0.42200208]),
 'score_time': array([0.00299048, 0.00203776, 0.00500631, 0.0019989 , 0.00299263]),
 'test_score': array([0.913 , 0.9235, 0.8955, 0.9075, 0.918 ]),
 'train_score': array([0.92425 , 0.921125, 0.9285  , 0.92325 , 0.922125])}
%%time

estimator_range = range(1, n_estimators+1)
train_score, test_score = validation_curve(
    clf, X, y, param_name='n_estimators', param_range=estimator_range,
    cv=cv, scoring='accuracy'
)

print('train_score:',train_score[-1])
print('test_score:', test_score[-1])
train_score: [0.92425  0.921125 0.9285   0.92325  0.922125]
test_score: [0.913  0.9235 0.8955 0.9075 0.918 ]
Wall time: 58.7 s
print('train_mae_mean:\n', np.abs(train_score).mean(axis=1))
print('test_mae_mean:\n', np.abs(test_score).mean(axis=1))
train_score.mean(axis=1), test_score.mean(axis=1)
train_mae_mean:
 array([0.912775, 0.916075, 0.91585 , 0.91695 , 0.917125, 0.917225,
        0.91725 , 0.9175  , 0.91745 , 0.917925, 0.91755 , 0.918025,
        0.917975, 0.91835 , 0.918225, 0.918625, 0.919   , 0.91905 ,
        0.918975, 0.9191  , 0.91955 , 0.919525, 0.9198  , 0.9199  ,
        0.919975, 0.920025, 0.9201  , 0.92005 , 0.920125, 0.9208  ,
        0.921425, 0.9218  , 0.921875, 0.922025, 0.922125, 0.9221  ,
        0.92225 , 0.922275, 0.922275, 0.92235 , 0.9226  , 0.9229  ,
        0.923   , 0.9233  , 0.923375, 0.923275, 0.923325, 0.9234  ,
        0.923675, 0.92385 ]),
test_mae_mean:
 array([0.9049, 0.9072, 0.9082, 0.9085, 0.9087, 0.9084, 0.9082, 0.9091,
        0.9087, 0.9089, 0.9091, 0.9092, 0.9089, 0.9101, 0.9102, 0.9108,
        0.9102, 0.9107, 0.9105, 0.9109, 0.9104, 0.9102, 0.9109, 0.9109,
        0.9103, 0.9105, 0.9105, 0.9103, 0.9106, 0.9111, 0.9121, 0.9124,
        0.9124, 0.9122, 0.9119, 0.912 , 0.912 , 0.9117, 0.9114, 0.911 ,
        0.911 , 0.9113, 0.9111, 0.9107, 0.9108, 0.911 , 0.9109, 0.9113,
        0.9114, 0.9115])
%%time

params_xgb = params.copy()
num_round = params_xgb['n_estimators']
params_xgb['eta'] = params['learning_rate']
del params_xgb['n_estimators']
del params_xgb['learning_rate']

res = xgb.cv(params_xgb, xgb.DMatrix(X, y), num_round, folds=cv, metrics='error')
Wall time: 2.37 s
res['train-error-mean'] = 1 - res['train-error-mean']
res['test-error-mean'] = 1 - res['test-error-mean']
print(res)
    train-error-mean  train-error-std  test-error-mean  test-error-std
0           0.912775         0.002296           0.9049        0.007493
1           0.916075         0.003749           0.9072        0.007679
2           0.915850         0.003048           0.9082        0.006615
3           0.916950         0.002090           0.9085        0.008503
4           0.917125         0.002028           0.9087        0.008606
5           0.917225         0.002191           0.9084        0.009356
6           0.917250         0.002219           0.9082        0.009114
7           0.917500         0.002318           0.9091        0.009672
8           0.917450         0.002308           0.9087        0.009405
9           0.917925         0.002467           0.9089        0.009410
10          0.917550         0.002248           0.9091        0.009313
11          0.918025         0.002384           0.9092        0.009389
12          0.917975         0.002583           0.9089        0.009124
13          0.918350         0.002095           0.9101        0.008840
14          0.918225         0.002223           0.9102        0.008658
15          0.918625         0.002204           0.9108        0.008388
16          0.919000         0.002904           0.9102        0.009495
17          0.919050         0.002639           0.9107        0.008376
18          0.918975         0.002451           0.9105        0.008562
19          0.919100         0.002613           0.9109        0.008645
20          0.919550         0.003244           0.9104        0.008570
21          0.919525         0.003234           0.9102        0.008761
22          0.919800         0.003307           0.9109        0.008505
23          0.919900         0.003537           0.9109        0.008505
24          0.919975         0.003535           0.9103        0.008376
25          0.920025         0.003365           0.9105        0.008087
26          0.920100         0.003451           0.9105        0.008390
27          0.920050         0.003514           0.9103        0.008412
28          0.920125         0.003521           0.9106        0.007908
29          0.920800         0.003303           0.9111        0.008351
30          0.921425         0.002912           0.9121        0.009330
31          0.921800         0.002910           0.9124        0.009330
32          0.921875         0.002739           0.9124        0.009330
33          0.922025         0.002837           0.9122        0.009405
34          0.922125         0.002860           0.9119        0.009957
35          0.922100         0.002807           0.9120        0.009497
36          0.922250         0.002777           0.9120        0.009370
37          0.922275         0.002636           0.9117        0.009569
38          0.922275         0.002540           0.9114        0.009609
39          0.922350         0.002477           0.9110        0.009680
40          0.922600         0.002607           0.9110        0.009633
41          0.922900         0.002838           0.9113        0.010033
42          0.923000         0.002787           0.9111        0.009805
43          0.923300         0.002612           0.9107        0.009516
44          0.923375         0.002614           0.9108        0.009667
45          0.923275         0.002897           0.9110        0.009772
46          0.923325         0.002718           0.9109        0.009728
47          0.923400         0.002685           0.9113        0.009811
48          0.923675         0.002820           0.9114        0.009764
49          0.923850         0.002551           0.9115        0.009597

validation_curve用了58.7s,而xgboost.cv只用了2.37s,差距巨大!

  • 11
    点赞
  • 48
    收藏
    觉得还不错? 一键收藏
  • 3
    评论
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值