XGBoost 1 - Basics and Simple Usage

XGBoost

eXtreme Gradient Boosting (XGBoost) is an optimized implementation of the gradient boosting machine (GBM); it is fast and effective.

  • Introduction to xgboost
    • Features of xgboost
    • Basic usage guide
  • Theoretical foundations of xgboost
    • Supervised learning
    • CART
    • Boosting
    • Gradient boosting
    • xgboost
  • xgboost in practice
    • Feature engineering
    • Parameter tuning

1 - Getting to Know xgboost

  • Introduction
  • Advantages
  • Hands-on practice

1.1 - Introduction

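xgboost is an optimized C++ implementation of the gradient boosting machine. The name "gradient boosting machine" breaks down as follows:

  • machine: a machine learning model that captures the patterns that generate the data
  • boosting machine: a model that combines weak learners into a strong learner
  • gradient boosting machine: a model that combines weak learners in a gradient-descent fashion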

1.1.1 - machine

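Modeling the patterns that generate the data is usually formalized as minimizing an objective function, which typically consists of two parts:

Objective function:

  • Loss function (selects the model that best fits the training data)
    • Regression: squared loss $(\hat{y} - y)^2$
    • Classification: 0-1 loss, logistic loss, hinge loss, exponential loss
  • Regularization term (selects the simplest model)
    • L2 regularization
    • L1 regularization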
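The two-part objective can be written out for a simple linear model. A minimal sketch, assuming a linear predictor `X @ w`, squared loss, and L2/L1 penalties weighted by the illustrative names `lam` and `alpha` (xgboost itself applies its penalties to the leaf weights of trees rather than to linear coefficients):

```python
import numpy as np

def objective(w, X, y, lam=1.0, alpha=0.0):
    """Squared-error loss plus L2 (lam) and L1 (alpha) penalties on the weights w."""
    y_hat = X @ w                          # model predictions
    loss = np.sum((y_hat - y) ** 2)        # loss term: how well the model fits the data
    l2 = lam * np.sum(w ** 2)              # L2 term: favors small weights
    l1 = alpha * np.sum(np.abs(w))         # L1 term: favors sparse weights
    return loss + l2 + l1

X = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
y = np.array([1.0, 2.0, 3.0])
w = np.array([1.0, 2.0])                   # fits y exactly, so only the penalties remain
print(objective(w, X, y, lam=0.5, alpha=0.1))
```

Raising `lam` or `alpha` makes large or many non-zero weights more expensive, which is how the regularization term pushes toward a simpler model.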

1.1.2 - boosting machine

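  • boosting: combines weak learners into a strong learner
  • Weak learners: decision trees / classification and regression trees (CART, the weak learner used by xgboost)
    • Decision tree: each leaf node corresponds to a decision
    • Classification and regression tree: each leaf node holds a prediction score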
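To illustrate how regression-tree leaf scores are combined: each tree routes a sample to one leaf and returns that leaf's score, and the ensemble's prediction is the sum of the scores from all trees. A minimal sketch with two hand-written, hypothetical trees (the features `age`, `is_male` and `uses_computer_daily` are made up for illustration):

```python
def tree_1(x):
    # Root split on age; each branch ends in a leaf that stores a score
    if x["age"] < 15:
        return 2.0 if x["is_male"] else 0.5
    return -1.0

def tree_2(x):
    # A single split on a second (made-up) feature
    return 0.5 if x["uses_computer_daily"] else -0.5

def ensemble_predict(x, trees):
    """Additive model: the prediction is the sum of the leaf scores from every tree."""
    return sum(tree(x) for tree in trees)

boy = {"age": 10, "is_male": True, "uses_computer_daily": True}
grandpa = {"age": 70, "is_male": True, "uses_computer_daily": False}
print(ensemble_predict(boy, [tree_1, tree_2]))      # 2.5
print(ensemble_predict(grandpa, [tree_1, tree_2]))  # -1.5
```

Because the leaves hold real-valued scores rather than class decisions, the outputs of several trees can be added, which is what makes CART a convenient weak learner for boosting.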


1.1.3 - gradient boosting machine

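The classic boosting machine algorithm is AdaBoost, whose loss function is the exponential loss. Friedman generalized AdaBoost into a general gradient boosting framework, yielding the gradient boosting machine: boosting is treated as a numerical optimization problem and solved in a manner similar to gradient descent. As a result, any differentiable loss function can be used, and the supported tasks extend from binary classification to multi-class classification and beyond.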
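The gradient-descent view can be made concrete for squared loss, where the negative gradient is simply the residual y - F. A minimal numpy sketch, assuming a single numeric feature and depth-1 regression trees (stumps) as weak learners; this is the plain first-order gradient boosting recipe, not xgboost's own second-order algorithm:

```python
import numpy as np

def fit_stump(x, r):
    """Fit a depth-1 regression tree (a stump) on the 1-D feature x to the targets r.
    Returns (threshold, left_value, right_value) with the lowest squared error."""
    best = None
    for t in np.unique(x)[:-1]:          # drop the largest value so both sides stay non-empty
        left, right = r[x <= t], r[x > t]
        sse = ((left - left.mean()) ** 2).sum() + ((right - right.mean()) ** 2).sum()
        if best is None or sse < best[0]:
            best = (sse, t, left.mean(), right.mean())
    _, t, left_value, right_value = best
    return t, left_value, right_value

def predict_stump(stump, x):
    t, left_value, right_value = stump
    return np.where(x <= t, left_value, right_value)

def gradient_boost(x, y, n_rounds=50, learning_rate=0.3):
    """Gradient boosting with squared loss L = 0.5 * (y - F)^2.
    Its negative gradient -dL/dF = y - F is the residual, so each round fits a stump to it."""
    f0 = y.mean()                        # start from the best constant model
    F = np.full(len(y), f0)
    stumps = []
    for _ in range(n_rounds):
        residual = y - F                 # negative gradient at the current predictions
        stump = fit_stump(x, residual)
        F = F + learning_rate * predict_stump(stump, x)   # shrunken step along the fitted direction
        stumps.append(stump)
    return f0, learning_rate, stumps

def gb_predict(model, x):
    f0, learning_rate, stumps = model
    return f0 + learning_rate * sum(predict_stump(s, x) for s in stumps)

x = np.linspace(0.0, 6.0, 60)
y = np.sin(x)
model = gradient_boost(x, y)
baseline_mse = np.mean((y - y.mean()) ** 2)
boosted_mse = np.mean((gb_predict(model, x) - y) ** 2)
print(baseline_mse, boosted_mse)
```

Swapping in another differentiable loss only changes the `residual` line to that loss's negative gradient, which is the generalization Friedman's framework provides.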

1.2 - Advantages of xgboost

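Features:

  • Regularization: the standard GBM (gradient boosting machine) has no explicit regularization term, whereas xgboost adds one to its objective
  • Parallel processing: split finding within each tree is parallelized (the trees themselves are still built one after another)
  • Custom objectives and evaluation metrics: a custom objective only needs to provide the first and second derivatives of the loss function
  • Pruning: a standard GBM stops splitting as soon as a new split yields a negative gain, while xgboost keeps splitting up to the maximum depth and then prunes the tree back
  • Supports online learning
  • Excels on structured data, whereas deep learning performs well on unstructured data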
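Supplying a custom objective means supplying those two derivatives. For the logistic (log) loss with raw margin F and p = sigmoid(F), the gradient is p - y and the hessian is p(1 - p). A minimal numpy sketch; the `(preds, dtrain) -> (grad, hess)` form follows the `obj` callback of `xgb.train`, and it assumes `preds` are raw margins, which is worth checking against the xgboost version in use:

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def logistic_grad_hess(margin, y):
    """First and second derivatives of the log loss with respect to the raw margin."""
    p = sigmoid(margin)
    grad = p - y              # dL/dF
    hess = p * (1.0 - p)      # d^2L/dF^2
    return grad, hess

def logistic_obj(preds, dtrain):
    """Custom objective in the (preds, dtrain) -> (grad, hess) form used by xgb.train.
    Assumes preds are raw (untransformed) margins."""
    return logistic_grad_hess(preds, dtrain.get_label())

grad, hess = logistic_grad_hess(np.array([-1.0, 0.0, 2.0]), np.array([0.0, 1.0, 1.0]))
print(grad, hess)
# With xgboost installed and a DMatrix dtrain:
# booster = xgb.train(params, dtrain, num_boost_round=10, obj=logistic_obj)
```

The hessian is always positive for this loss, which keeps the second-order step that xgboost takes well defined.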

1.3 - xgboost in Practice

1.3.1 - General Workflow for Data Science Tasks


[Figure: general workflow for data science tasks]

1.3.2 - Using xgboost through the sklearn API

from sklearn.datasets import load_svmlight_file
from sklearn.metrics import accuracy_score
from xgboost import XGBClassifier

Load the data

file_path  = "./data/"
X_train, y_train = load_svmlight_file(file_path+"agaricus.txt.train")
X_test, y_test = load_svmlight_file(file_path+"agaricus.txt.test")
print(X_train.shape, y_train.shape)
(6513, 126) (6513,)
print(X_test.shape, y_test.shape)
(1611, 126) (1611,)

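Parameters (the defaults quoted are those of the native xgboost library):

  • max_depth: maximum depth of a tree. Default: 6
  • learning_rate: step-size shrinkage applied to each update to prevent overfitting. After each boosting step the weights of the new features are available directly, and learning_rate shrinks these weights to make the boosting process more conservative. Default: 0.3; range: [0, 1]
  • silent: 0 prints runtime messages, 1 runs silently without printing them. Default: 0. Newer xgboost versions replace silent with verbosity
  • objective: defines the learning task and its objective; "binary:logistic" is logistic regression for binary classification and outputs probabilities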
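With objective="binary:logistic", the summed tree scores (the raw margin) are passed through the sigmoid to give a probability, and a class label then follows from a threshold. A small numpy sketch of that mapping, assuming a threshold of 0.5 (the function names are illustrative, not xgboost API):

```python
import numpy as np

def margin_to_proba(margin):
    """Map raw margins (sums of leaf scores) to probabilities with the sigmoid."""
    return 1.0 / (1.0 + np.exp(-np.asarray(margin, dtype=float)))

def proba_to_label(proba, threshold=0.5):
    """Turn probabilities into 0/1 class labels."""
    return (np.asarray(proba) > threshold).astype(int)

probs = margin_to_proba([-2.0, -0.5, 3.0])
print(probs)
print(proba_to_label(probs))   # [0 0 1]
```

In the sklearn wrapper, `predict_proba` exposes the probabilities and `predict` the thresholded labels.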

Configure the model

xgbc = XGBClassifier(max_depth=2,
                     learning_rate=1,
                     n_estimators=2,  # number of boosting iterations, i.e. number of trees
                     silent=False,    # print runtime messages ("verbosity" in newer versions)
                     objective="binary:logistic"
                    )

Train the model

xgbc.fit(X_train, y_train)

Training accuracy

pred_train = xgbc.predict(X_train)
pred_train = [round(x) for x in pred_train]
train_score = accuracy_score(y_train, pred_train)
print("Train Accuracy: %.2f%%" % (train_score * 100))
Train Accuracy: 97.77%

Test accuracy

pred_test = xgbc.predict(X_test)
pred_test = [1 if x >= 0.5 else 0 for x in pred_test]
print("Test Accuracy: %.2f%%" % (accuracy_score(y_test, pred_test) * 100))
Test Accuracy: 97.83% 

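1.3.3 - Validation Set

Hold out part of the training data so that it takes no part in fitting the model parameters. This held-out portion is called the validation set.

  • Train the model on the remaining data, then evaluate the trained model on the validation set
  • Performance on the validation set serves as an estimate of the model's performance on unseen data; choose the model that performs best on the validation set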
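The hold-out idea needs nothing more than shuffled indices. A minimal numpy sketch (the function `holdout_split` is hypothetical, not part of sklearn); it rounds the validation size up, so the sizes match the 4363/2150 split printed below, but the rows chosen differ from `train_test_split` because the random generators differ:

```python
import numpy as np

def holdout_split(n_samples, validation_fraction=0.33, seed=42):
    """Shuffle the sample indices and hold out a fraction of them for validation."""
    rng = np.random.default_rng(seed)
    idx = rng.permutation(n_samples)
    n_val = int(np.ceil(n_samples * validation_fraction))   # round the validation size up
    return idx[n_val:], idx[:n_val]                          # (train indices, validation indices)

train_idx, val_idx = holdout_split(6513)
print(len(train_idx), len(val_idx))   # 4363 2150
```

The two index arrays are disjoint, so no validation row is ever seen during training.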
from sklearn.model_selection import train_test_split

Split into training and validation sets

file_path  = "./data/"
X_train, y_train = load_svmlight_file(file_path+"agaricus.txt.train")
X_test, y_test = load_svmlight_file(file_path+"agaricus.txt.test")

X_train, X_validation, y_train, y_validation = train_test_split(X_train, y_train, test_size= 0.33, random_state=42)
print(X_train.shape, y_train.shape)
print(X_validation.shape, y_validation.shape)
(4363, 126) (4363,)
(2150, 126) (2150,)
xgbc = XGBClassifier(max_depth=2, learning_rate=1, n_estimators=2, silent=False, objective="binary:logistic")
xgbc.fit(X_train, y_train, verbose=True)
# performance on the validation set
pred_val = xgbc.predict(X_validation)
pred_val = [round(x) for x in pred_val]
print("Validation Accuracy: %.0f%%" % (accuracy_score(y_validation, pred_val) * 100))
Validation Accuracy: 97%
# performance on the training set
pred_train = xgbc.predict(X_train)
pred_train = [round(x) for x in pred_train]
print("Train Accuracy: %.0f%%" % (accuracy_score(y_train, pred_train) * 100))
Train Accuracy: 98%
# performance on the test set
pred_test = xgbc.predict(X_test)
pred_test = [round(x) for x in pred_test]
print("Test Accuracy: %.0f%%" % (accuracy_score(y_test, pred_test) * 100))
Test Accuracy: 98%

1.3.4 - Learning Curves: Number of Weak Learners (i.e., Number of Iterations)

import matplotlib.pyplot as plt
n_iteration = 100

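xgbc = XGBClassifier(max_depth=2, learning_rate=0.1, n_estimators=n_iteration, objective="binary:logistic")
# monitor the training set (validation_0) and the validation set (validation_1)
eval_set = [(X_train, y_train), (X_validation, y_validation)]
# newer xgboost versions expect eval_metric in the XGBClassifier constructor instead of fit()
xgbc.fit(X_train, y_train, eval_set=eval_set, eval_metric=["error", "logloss"], verbose=True)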
[0] validation_0-error:0.044236 validation_0-logloss:0.614162   validation_1-error:0.051163 validation_1-logloss:0.615457
[1] validation_0-error:0.039193 validation_0-logloss:0.549179   validation_1-error:0.046512 validation_1-logloss:0.551203
[2] validation_0-error:0.044236 validation_0-logloss:0.494366   validation_1-error:0.051163 validation_1-logloss:0.497442
[3] validation_0-error:0.039193 validation_0-logloss:0.447845   validation_1-error:0.046512 validation_1-logloss:0.451486
[4] validation_0-error:0.039193 validation_0-logloss:0.407646   validation_1-error:0.046512 validation_1-logloss:0.411989
[5] validation_0-error:0.039193 validation_0-logloss:0.371941   validation_1-error:0.046512 validation_1-logloss:0.377037
[6] validation_0-error:0.022003 validation_0-logloss:0.341067   validation_1-error:0.026047 validation_1-logloss:0.346286
[7] validation_0-error:0.039193 validation_0-logloss:0.313232   validation_1-error:0.046512 validation_1-logloss:0.319077
[8] validation_0-error:0.039193 validation_0-logloss:0.288775   validation_1-error:0.046512 validation_1-logloss:0.294526
[9] validation_0-error:0.022003 validation_0-logloss:0.267046   validation_1-error:0.026047 validation_1-logloss:0.273228
[10]    validation_0-error:0.004813 validation_0-logloss:0.247238   validation_1-error:0.008372 validation_1-logloss:0.253542
[11]    validation_0-error:0.004813 validation_0-logloss:0.229689   validation_1-error:0.008372 validation_1-logloss:0.236248
[12]    validation_0-error:0.010085 validation_0-logloss:0.210475   validation_1-error:0.015349 validation_1-logloss:0.216868
[13]    validation_0-error:0.015586 validation_0-logloss:0.193727   validation_1-error:0.02093  validation_1-logloss:0.199968
[14]    validation_0-error:0.015586 validation_0-logloss:0.179108   validation_1-error:0.02093  validation_1-logloss:0.185209
[15]    validation_0-error:0.015586 validation_0-logloss:0.166333   validation_1-error:0.02093  validation_1-logloss:0.172308
[16]    validation_0-error:0.015586 validation_0-logloss:0.15516    validation_1-error:0.02093  validation_1-logloss:0.16102
[17]    validation_0-error:0.015586 validation_0-logloss:0.145382   validation_1-error:0.02093  validation_1-logloss:0.151137
[18]    validation_0-error:0.015586 validation_0-logloss:0.13682    validation_1-error:0.02093  validation_1-logloss:0.142481
[19]    validation_0-error:0.015586 validation_0-logloss:0.129854   validation_1-error:0.02093  validation_1-logloss:0.135452
[20]    validation_0-error:0.015586 validation_0-logloss:0.122889   validation_1-error:0.02093  validation_1-logloss:0.128415
[21]    validation_0-error:0.023608 validation_0-logloss:0.11724    validation_1-error:0.029302 validation_1-logloss:0.122718
[22]    validation_0-error:0.023608 validation_0-logloss:0.111548   validation_1-error:0.029302 validation_1-logloss:0.116973
[23]    validation_0-error:0.02017  validation_0-logloss:0.106935   validation_1-error:0.024186 validation_1-logloss:0.112492
[24]    validation_0-error:0.02017  validation_0-logloss:0.102711   validation_1-error:0.024186 validation_1-logloss:0.108251
[25]    validation_0-error:0.02017  validation_0-logloss:0.098366   validation_1-error:0.024186 validation_1-logloss:0.103854
[26]    validation_0-error:0.02017  validation_0-logloss:0.094848   validation_1-error:0.024186 validation_1-logloss:0.100122
[27]    validation_0-error:0.02017  validation_0-logloss:0.09125    validation_1-error:0.024186 validation_1-logloss:0.096787
[28]    validation_0-error:0.02017  validation_0-logloss:0.087968   validation_1-error:0.024186 validation_1-logloss:0.093459
[29]    validation_0-error:0.02017  validation_0-logloss:0.084816   validation_1-error:0.024186 validation_1-logloss:0.090229
[30]    validation_0-error:0.02017  validation_0-logloss:0.081983   validation_1-error:0.024186 validation_1-logloss:0.087354
[31]    validation_0-error:0.02017  validation_0-logloss:0.079313   validation_1-error:0.024186 validation_1-logloss:0.084619
[32]    validation_0-error:0.012148 validation_0-logloss:0.074708   validation_1-error:0.015814 validation_1-logloss:0.080086
[33]    validation_0-error:0.012148 validation_0-logloss:0.071661   validation_1-error:0.015814 validation_1-logloss:0.077247
[34]    validation_0-error:0.02017  validation_0-logloss:0.069014   validation_1-error:0.024186 validation_1-logloss:0.074588
[35]    validation_0-error:0.014669 validation_0-logloss:0.06648    validation_1-error:0.018605 validation_1-logloss:0.072239
[36]    validation_0-error:0.009397 validation_0-logloss:0.064195   validation_1-error:0.011628 validation_1-logloss:0.069621
[37]    validation_0-error:0.001375 validation_0-logloss:0.062203   validation_1-error:0.003256 validation_1-logloss:0.06757
[38]    validation_0-error:0.001375 validation_0-logloss:0.060052   validation_1-error:0.003256 validation_1-logloss:0.065462
[39]    validation_0-error:0.001375 validation_0-logloss:0.05799    validation_1-error:0.003256 validation_1-logloss:0.063569
[40]    validation_0-error:0.001375 validation_0-logloss:0.056169   validation_1-error:0.003256 validation_1-logloss:0.061491
[41]    validation_0-error:0.001375 validation_0-logloss:0.054376   validation_1-error:0.003256 validation_1-logloss:0.059743
[42]    validation_0-error:0.009397 validation_0-logloss:0.052657   validation_1-error:0.011628 validation_1-logloss:0.058177
[43]    validation_0-error:0.001375 validation_0-logloss:0.051002   validation_1-error:0.003256 validation_1-logloss:0.056733
[44]    validation_0-error:0.001375 validation_0-logloss:0.049429   validation_1-error:0.003256 validation_1-logloss:0.054922
[45]    validation_0-error:0.001375 validation_0-logloss:0.047924   validation_1-error:0.003256 validation_1-logloss:0.053362
[46]    validation_0-error:0.001375 validation_0-logloss:0.046491   validation_1-error:0.003256 validation_1-logloss:0.051973
[47]    validation_0-error:0.001375 validation_0-logloss:0.045115   validation_1-error:0.003256 validation_1-logloss:0.050731
[48]    validation_0-error:0.001375 validation_0-logloss:0.04384    validation_1-error:0.003256 validation_1-logloss:0.049218
[49]    validation_0-error:0.001375 validation_0-logloss:0.04261    validation_1-error:0.003256 validation_1-logloss:0.048026
[50]    validation_0-error:0.001375 validation_0-logloss:0.041414   validation_1-error:0.003256 validation_1-logloss:0.046635
[51]    validation_0-error:0.001375 validation_0-logloss:0.04024    validation_1-error:0.003256 validation_1-logloss:0.04559
[52]    validation_0-error:0.001375 validation_0-logloss:0.039108   validation_1-error:0.003256 validation_1-logloss:0.044651
[53]    validation_0-error:0.001375 validation_0-logloss:0.038046   validation_1-error:0.003256 validation_1-logloss:0.043404
[54]    validation_0-error:0.001375 validation_0-logloss:0.036975   validation_1-error:0.003256 validation_1-logloss:0.042286
[55]    validation_0-error:0.001375 validation_0-logloss:0.035982   validation_1-error:0.003256 validation_1-logloss:0.041341
[56]    validation_0-error:0.001375 validation_0-logloss:0.035031   validation_1-error:0.003256 validation_1-logloss:0.040505
[57]    validation_0-error:0.001375 validation_0-logloss:0.034135   validation_1-error:0.003256 validation_1-logloss:0.039399
[58]    validation_0-error:0.001375 validation_0-logloss:0.033276   validation_1-error:0.003256 validation_1-logloss:0.038583
[59]    validation_0-error:0.001375 validation_0-logloss:0.032452   validation_1-error:0.003256 validation_1-logloss:0.037861
[60]    validation_0-error:0.001375 validation_0-logloss:0.031655   validation_1-error:0.003256 validation_1-logloss:0.036928
[61]    validation_0-error:0.001375 validation_0-logloss:0.030869   validation_1-error:0.003256 validation_1-logloss:0.035987
[62]    validation_0-error:0.001375 validation_0-logloss:0.030057   validation_1-error:0.003256 validation_1-logloss:0.035138
[63]    validation_0-error:0.001375 validation_0-logloss:0.029379   validation_1-error:0.003256 validation_1-logloss:0.034418
[64]    validation_0-error:0.001375 validation_0-logloss:0.028683   validation_1-error:0.003256 validation_1-logloss:0.033762
[65]    validation_0-error:0.001375 validation_0-logloss:0.028014   validation_1-error:0.003256 validation_1-logloss:0.033187
[66]    validation_0-error:0.001375 validation_0-logloss:0.027338   validation_1-error:0.003256 validation_1-logloss:0.032326
[67]    validation_0-error:0.001375 validation_0-logloss:0.026727   validation_1-error:0.003256 validation_1-logloss:0.031581
[68]    validation_0-error:0.001375 validation_0-logloss:0.026087   validation_1-error:0.003256 validation_1-logloss:0.031107
[69]    validation_0-error:0.001375 validation_0-logloss:0.025474   validation_1-error:0.003256 validation_1-logloss:0.030427
[70]    validation_0-error:0.001375 validation_0-logloss:0.024911   validation_1-error:0.003256 validation_1-logloss:0.029905
[71]    validation_0-error:0.001375 validation_0-logloss:0.024368   validation_1-error:0.003256 validation_1-logloss:0.029239
[72]    validation_0-error:0.001375 validation_0-logloss:0.023829   validation_1-error:0.003256 validation_1-logloss:0.028852
[73]    validation_0-error:0.001375 validation_0-logloss:0.023316   validation_1-error:0.003256 validation_1-logloss:0.028419
[74]    validation_0-error:0.001375 validation_0-logloss:0.02278    validation_1-error:0.003256 validation_1-logloss:0.027854
[75]    validation_0-error:0.001375 validation_0-logloss:0.022305   validation_1-error:0.003256 validation_1-logloss:0.027263
[76]    validation_0-error:0.001375 validation_0-logloss:0.021837   validation_1-error:0.003256 validation_1-logloss:0.026841
[77]    validation_0-error:0.001375 validation_0-logloss:0.02139    validation_1-error:0.003256 validation_1-logloss:0.02647
[78]    validation_0-error:0.001375 validation_0-logloss:0.020914   validation_1-error:0.003256 validation_1-logloss:0.02589
[79]    validation_0-error:0.001375 validation_0-logloss:0.020452   validation_1-error:0.003256 validation_1-logloss:0.025369
[80]    validation_0-error:0.001375 validation_0-logloss:0.020058   validation_1-error:0.003256 validation_1-logloss:0.024872
[81]    validation_0-error:0.001375 validation_0-logloss:0.019648   validation_1-error:0.003256 validation_1-logloss:0.024367
[82]    validation_0-error:0.001375 validation_0-logloss:0.019268   validation_1-error:0.003256 validation_1-logloss:0.023936
[83]    validation_0-error:0.001375 validation_0-logloss:0.018878   validation_1-error:0.003256 validation_1-logloss:0.023496
[84]    validation_0-error:0.001375 validation_0-logloss:0.018503   validation_1-error:0.003256 validation_1-logloss:0.023169
[85]    validation_0-error:0.001375 validation_0-logloss:0.018148   validation_1-error:0.003256 validation_1-logloss:0.022877
[86]    validation_0-error:0.001375 validation_0-logloss:0.017783   validation_1-error:0.003256 validation_1-logloss:0.022427
[87]    validation_0-error:0.001375 validation_0-logloss:0.01746    validation_1-error:0.003256 validation_1-logloss:0.022145
[88]    validation_0-error:0.001375 validation_0-logloss:0.017149   validation_1-error:0.003256 validation_1-logloss:0.021805
[89]    validation_0-error:0.001375 validation_0-logloss:0.016832   validation_1-error:0.003256 validation_1-logloss:0.021546
[90]    validation_0-error:0.001375 validation_0-logloss:0.016305   validation_1-error:0.003256 validation_1-logloss:0.020802
[91]    validation_0-error:0.001375 validation_0-logloss:0.016013   validation_1-error:0.003256 validation_1-logloss:0.020549
[92]    validation_0-error:0.001375 validation_0-logloss:0.015729   validation_1-error:0.003256 validation_1-logloss:0.02018
[93]    validation_0-error:0.001375 validation_0-logloss:0.015467   validation_1-error:0.003256 validation_1-logloss:0.019926
[94]    validation_0-error:0.001375 validation_0-logloss:0.015202   validation_1-error:0.003256 validation_1-logloss:0.019611
[95]    validation_0-error:0.001375 validation_0-logloss:0.014931   validation_1-error:0.003256 validation_1-logloss:0.019267
[96]    validation_0-error:0.001375 validation_0-logloss:0.014652   validation_1-error:0.003256 validation_1-logloss:0.018949
[97]    validation_0-error:0.001375 validation_0-logloss:0.014399   validation_1-error:0.003256 validation_1-logloss:0.018651
[98]    validation_0-error:0.001375 validation_0-logloss:0.014151   validation_1-error:0.003256 validation_1-logloss:0.018445
[99]    validation_0-error:0.001375 validation_0-logloss:0.013908   validation_1-error:0.003256 validation_1-logloss:0.018252





XGBClassifier(base_score=0.5, booster='gbtree', colsample_bylevel=1,
       colsample_bytree=1, gamma=0, learning_rate=0.1, max_delta_step=0,
       max_depth=2, min_child_weight=1, missing=None, n_estimators=100,
       n_jobs=1, nthread=None, objective='binary:logistic', random_state=0,
       reg_alpha=0, reg_lambda=1, scale_pos_weight=1, seed=None,
       silent=True, subsample=1)
plt.rcParams["figure.figsize"] = (5., 3.)
result = xgbc.evals_result()

epochs = len(result["validation_0"]["error"])

fig, ax = plt.subplots()
ax.plot(list(range(epochs)), result["validation_0"]["error"], label="train")
ax.plot(list(range(epochs)), result["validation_1"]["error"], label="validation")
ax.legend()
plt.ylabel("error")
plt.xlabel("epoch")
plt.title("XGBoost error")
plt.show()

fig, ax = plt.subplots()
ax.plot(list(range(epochs)), result["validation_0"]["logloss"], label="train")
ax.plot(list(range(epochs)), result["validation_1"]["logloss"], label="validation")
ax.legend()
plt.ylabel("logloss")
plt.xlabel("epoch")
plt.title("XGBoost logloss")
plt.show()


[Figure: XGBoost error, train vs. validation]

[Figure: XGBoost logloss, train vs. validation]

# performance in the test set
pred_test = xgbc.predict(X_test)
pred_test = [round(x) for x in pred_test]
print("Test Accuracy: %.2f%%" % (accuracy_score(y_test, pred_test) * 100))
Test Accuracy: 99.81%
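The two metrics tracked above can be computed by hand. A minimal sketch: `binary_error` follows xgboost's documented definition of "error" (predictions above 0.5 count as positive), and `logloss` is the mean binary cross-entropy; the clipping constant `eps` is an added safeguard, not something taken from xgboost:

```python
import numpy as np

def binary_error(y_true, proba):
    """Fraction of samples whose thresholded prediction (proba > 0.5 -> 1) is wrong."""
    y_true = np.asarray(y_true)
    pred = (np.asarray(proba) > 0.5).astype(int)
    return np.mean(pred != y_true)

def logloss(y_true, proba, eps=1e-15):
    """Mean binary cross-entropy; eps keeps log() away from zero."""
    y_true = np.asarray(y_true, dtype=float)
    p = np.clip(np.asarray(proba, dtype=float), eps, 1 - eps)
    return -np.mean(y_true * np.log(p) + (1 - y_true) * np.log(1 - p))

y = [1, 0, 1, 0]
p = [0.9, 0.2, 0.4, 0.1]
print(binary_error(y, p))   # 0.25
print(logloss(y, p))
```

The example also shows why the curves behave differently: logloss keeps falling as predicted probabilities grow more confident, while error only moves when a prediction crosses the 0.5 threshold.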

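1.3.5 - Early Stopping

A method for preventing overfitting:

  • Monitor the model's performance on the validation set: if it has not improved after a fixed number of iterations, end the training process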
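The stopping rule can be replayed on recorded scores. A minimal sketch that mirrors the behaviour in this section's training log, assuming lower is better and only strict improvements count; `early_stopping` is a hypothetical helper, not xgboost's implementation:

```python
def early_stopping(scores, patience):
    """Scan per-round validation scores (lower is better) and stop once `patience`
    rounds have passed without a strict improvement. Returns (best_round, stopped_round)."""
    best_round, best_score = 0, scores[0]
    for i, s in enumerate(scores):
        if s < best_score:
            best_round, best_score = i, s
        if i - best_round >= patience:
            return best_round, i
    return best_round, len(scores) - 1

# Validation error per round (rounds 0-20), copied from the training log in this section
val_error = [0.051163, 0.046512, 0.051163, 0.046512, 0.046512, 0.046512,
             0.026047, 0.046512, 0.046512, 0.026047, 0.008372, 0.008372,
             0.015349, 0.02093, 0.02093, 0.02093, 0.02093, 0.02093,
             0.02093, 0.02093, 0.02093]
print(early_stopping(val_error, patience=10))   # (10, 20)
```

This matches the log: the best round is 10, and training stops after round 20, ten rounds later.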

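eval_set = [(X_validation, y_validation)]
# stop once the validation error has not improved for 10 consecutive rounds
# (in newer xgboost versions early_stopping_rounds is also a constructor argument)
xgbc.fit(X_train, y_train, eval_set=eval_set, eval_metric="error", early_stopping_rounds=10, verbose=True)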
[0] validation_0-error:0.051163
Will train until validation_0-error hasn't improved in 10 rounds.
[1] validation_0-error:0.046512
[2] validation_0-error:0.051163
[3] validation_0-error:0.046512
[4] validation_0-error:0.046512
[5] validation_0-error:0.046512
[6] validation_0-error:0.026047
[7] validation_0-error:0.046512
[8] validation_0-error:0.046512
[9] validation_0-error:0.026047
[10]    validation_0-error:0.008372
[11]    validation_0-error:0.008372
[12]    validation_0-error:0.015349
[13]    validation_0-error:0.02093
[14]    validation_0-error:0.02093
[15]    validation_0-error:0.02093
[16]    validation_0-error:0.02093
[17]    validation_0-error:0.02093
[18]    validation_0-error:0.02093
[19]    validation_0-error:0.02093
[20]    validation_0-error:0.02093
Stopping. Best iteration:
[10]    validation_0-error:0.008372


XGBClassifier(base_score=0.5, booster='gbtree', colsample_bylevel=1,
       colsample_bytree=1, gamma=0, learning_rate=0.1, max_delta_step=0,
       max_depth=2, min_child_weight=1, missing=None, n_estimators=100,
       n_jobs=1, nthread=None, objective='binary:logistic', random_state=0,
       reg_alpha=0, reg_lambda=1, scale_pos_weight=1, seed=None,
       silent=True, subsample=1)
result = xgbc.evals_result()
plt.plot(list(range(len(result["validation_0"]["error"]))), result["validation_0"]["error"])
plt.ylabel("error")
plt.title("XGBoost error-early stop")
plt.show()


[Figure: XGBoost error-early stop]

pred_test = xgbc.predict(X_test)
pred_test = [1 if x >= 0.5 else 0 for x in pred_test]
print("Test Accuracy: %.4f" % (accuracy_score(y_test, pred_test)))
Test Accuracy: 0.9808

1.3.6 - Cross Validation

from sklearn.model_selection import StratifiedKFold
from sklearn.model_selection import cross_val_score
import warnings
warnings.filterwarnings(action='ignore', category=DeprecationWarning)
# random_state has no effect without shuffle=True (newer scikit-learn rejects it), so it is omitted
kfolds = StratifiedKFold(n_splits=10)
print(kfolds)
StratifiedKFold(n_splits=10, random_state=None, shuffle=False)
results = cross_val_score(xgbc, X_train, y_train, cv=kfolds)
print(results)
print("%.2f%%, %.2f%%" % (results.mean() * 100, results.std() * 100))  # mean accuracy, standard deviation
[0.99771167 0.99771167 1.         1.         0.99542334 0.99770642
 0.99770642 1.         1.         1.        ]
99.86%, 0.15%
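Stratified k-fold keeps the class ratio of every fold close to that of the whole dataset. A minimal numpy sketch that deals each class's samples round-robin into the folds (`stratified_kfold_indices` is hypothetical; scikit-learn's `StratifiedKFold` assigns samples differently, so its folds will not match exactly). `cross_val_score` trains on each `train_idx`, scores on each `test_idx`, and the mean and standard deviation of those per-fold scores are what is printed above:

```python
import numpy as np

def stratified_kfold_indices(y, n_splits=10):
    """Yield (train_idx, test_idx) pairs whose test folds keep roughly
    the same class proportions as y (no shuffling)."""
    y = np.asarray(y)
    fold_of = np.empty(len(y), dtype=int)
    for cls in np.unique(y):
        members = np.flatnonzero(y == cls)
        fold_of[members] = np.arange(len(members)) % n_splits   # deal samples round-robin
    for k in range(n_splits):
        yield np.flatnonzero(fold_of != k), np.flatnonzero(fold_of == k)

y = np.array([0] * 80 + [1] * 20)
for train_idx, test_idx in stratified_kfold_indices(y, n_splits=10):
    print(len(test_idx), y[test_idx].mean())   # 10 samples per fold, 20% positives
```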

1.3.7 - GridSearchCV

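# sklearn.grid_search was removed in scikit-learn 0.20; newer code imports GridSearchCV from
# sklearn.model_selection and reads the results from clf.cv_results_ instead of clf.grid_scores_
from sklearn.grid_search import GridSearchCV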
xgbc = XGBClassifier(max_depth=2, objective="binary:logistic")
param_search = {
    "n_estimators":list(range(1, 10, 1)),
    "learning_rate":[x/10 for x in list(range(1, 11, 1))]
}
clf = GridSearchCV(estimator=xgbc, param_grid=param_search, cv=5)
clf.fit(X_train, y_train)
GridSearchCV(cv=5, error_score='raise',
       estimator=XGBClassifier(base_score=0.5, booster='gbtree', colsample_bylevel=1,
       colsample_bytree=1, gamma=0, learning_rate=0.1, max_delta_step=0,
       max_depth=2, min_child_weight=1, missing=None, n_estimators=100,
       n_jobs=1, nthread=None, objective='binary:logistic', random_state=0,
       reg_alpha=0, reg_lambda=1, scale_pos_weight=1, seed=None,
       silent=True, subsample=1),
       fit_params={}, iid=True, n_jobs=1,
       param_grid={'learning_rate': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0], 'n_estimators': [1, 2, 3, 4, 5, 6, 7, 8, 9]},
       pre_dispatch='2*n_jobs', refit=True, scoring=None, verbose=0)
clf.grid_scores_
[mean: 0.95576, std: 0.00954, params: {'learning_rate': 0.1, 'n_estimators': 1},
 mean: 0.95966, std: 0.01175, params: {'learning_rate': 0.1, 'n_estimators': 2},
 mean: 0.95576, std: 0.00954, params: {'learning_rate': 0.1, 'n_estimators': 3},
 mean: 0.95874, std: 0.01161, params: {'learning_rate': 0.1, 'n_estimators': 4},
 mean: 0.95966, std: 0.01175, params: {'learning_rate': 0.1, 'n_estimators': 5},
 mean: 0.95966, std: 0.01175, params: {'learning_rate': 0.1, 'n_estimators': 6},
 mean: 0.96997, std: 0.01554, params: {'learning_rate': 0.1, 'n_estimators': 7},
 mean: 0.96379, std: 0.01191, params: {'learning_rate': 0.1, 'n_estimators': 8},
 mean: 0.96402, std: 0.01220, params: {'learning_rate': 0.1, 'n_estimators': 9},
 mean: 0.95576, std: 0.00954, params: {'learning_rate': 0.2, 'n_estimators': 1},
 mean: 0.95920, std: 0.00824, params: {'learning_rate': 0.2, 'n_estimators': 2},
 mean: 0.97181, std: 0.01766, params: {'learning_rate': 0.2, 'n_estimators': 3},
 mean: 0.95966, std: 0.01175, params: {'learning_rate': 0.2, 'n_estimators': 4},
 mean: 0.97570, std: 0.01759, params: {'learning_rate': 0.2, 'n_estimators': 5},
 mean: 0.97937, std: 0.01934, params: {'learning_rate': 0.2, 'n_estimators': 6},
 mean: 0.98212, std: 0.00940, params: {'learning_rate': 0.2, 'n_estimators': 7},
 mean: 0.98441, std: 0.00484, params: {'learning_rate': 0.2, 'n_estimators': 8},
 mean: 0.97914, std: 0.00751, params: {'learning_rate': 0.2, 'n_estimators': 9},
 mean: 0.95576, std: 0.00954, params: {'learning_rate': 0.3, 'n_estimators': 1},
 mean: 0.95576, std: 0.00954, params: {'learning_rate': 0.3, 'n_estimators': 2},
 mean: 0.97800, std: 0.00585, params: {'learning_rate': 0.3, 'n_estimators': 3},
 mean: 0.96081, std: 0.00956, params: {'learning_rate': 0.3, 'n_estimators': 4},
 mean: 0.98441, std: 0.00320, params: {'learning_rate': 0.3, 'n_estimators': 5},
 mean: 0.97937, std: 0.00397, params: {'learning_rate': 0.3, 'n_estimators': 6},
 mean: 0.98556, std: 0.00426, params: {'learning_rate': 0.3, 'n_estimators': 7},
 mean: 0.97823, std: 0.00579, params: {'learning_rate': 0.3, 'n_estimators': 8},
 mean: 0.97983, std: 0.00604, params: {'learning_rate': 0.3, 'n_estimators': 9},
 mean: 0.95576, std: 0.00954, params: {'learning_rate': 0.4, 'n_estimators': 1},
 mean: 0.95576, std: 0.00954, params: {'learning_rate': 0.4, 'n_estimators': 2},
 mean: 0.97800, std: 0.00585, params: {'learning_rate': 0.4, 'n_estimators': 3},
 mean: 0.96768, std: 0.00711, params: {'learning_rate': 0.4, 'n_estimators': 4},
 mean: 0.97548, std: 0.00642, params: {'learning_rate': 0.4, 'n_estimators': 5},
 mean: 0.97479, std: 0.00660, params: {'learning_rate': 0.4, 'n_estimators': 6},
 mean: 0.98419, std: 0.00493, params: {'learning_rate': 0.4, 'n_estimators': 7},
 mean: 0.99083, std: 0.00576, params: {'learning_rate': 0.4, 'n_estimators': 8},
 mean: 0.99335, std: 0.00255, params: {'learning_rate': 0.4, 'n_estimators': 9},
 mean: 0.95576, std: 0.00954, params: {'learning_rate': 0.5, 'n_estimators': 1},
 mean: 0.95576, std: 0.00954, params: {'learning_rate': 0.5, 'n_estimators': 2},
 mean: 0.97593, std: 0.00416, params: {'learning_rate': 0.5, 'n_estimators': 3},
 mean: 0.97112, std: 0.00926, params: {'learning_rate': 0.5, 'n_estimators': 4},
 mean: 0.98694, std: 0.00395, params: {'learning_rate': 0.5, 'n_estimators': 5},
 mean: 0.98143, std: 0.00603, params: {'learning_rate': 0.5, 'n_estimators': 6},
 mean: 0.99198, std: 0.00507, params: {'learning_rate': 0.5, 'n_estimators': 7},
 mean: 0.99404, std: 0.00443, params: {'learning_rate': 0.5, 'n_estimators': 8},
 mean: 0.99862, std: 0.00134, params: {'learning_rate': 0.5, 'n_estimators': 9},
 mean: 0.95576, std: 0.00954, params: {'learning_rate': 0.6, 'n_estimators': 1},
 mean: 0.95554, std: 0.00936, params: {'learning_rate': 0.6, 'n_estimators': 2},
 mean: 0.97548, std: 0.00642, params: {'learning_rate': 0.6, 'n_estimators': 3},
 mean: 0.97410, std: 0.00751, params: {'learning_rate': 0.6, 'n_estimators': 4},
 mean: 0.98304, std: 0.00653, params: {'learning_rate': 0.6, 'n_estimators': 5},
 mean: 0.99244, std: 0.00792, params: {'learning_rate': 0.6, 'n_estimators': 6},
 mean: 0.99771, std: 0.00218, params: {'learning_rate': 0.6, 'n_estimators': 7},
 mean: 0.99794, std: 0.00222, params: {'learning_rate': 0.6, 'n_estimators': 8},
 mean: 0.99862, std: 0.00134, params: {'learning_rate': 0.6, 'n_estimators': 9},
 mean: 0.95576, std: 0.00954, params: {'learning_rate': 0.7, 'n_estimators': 1},
 mean: 0.97387, std: 0.01725, params: {'learning_rate': 0.7, 'n_estimators': 2},
 mean: 0.97823, std: 0.00610, params: {'learning_rate': 0.7, 'n_estimators': 3},
 mean: 0.97983, std: 0.00726, params: {'learning_rate': 0.7, 'n_estimators': 4},
 mean: 0.99060, std: 0.00275, params: {'learning_rate': 0.7, 'n_estimators': 5},
 mean: 0.99427, std: 0.00162, params: {'learning_rate': 0.7, 'n_estimators': 6},
 mean: 0.99679, std: 0.00183, params: {'learning_rate': 0.7, 'n_estimators': 7},
 mean: 0.99702, std: 0.00186, params: {'learning_rate': 0.7, 'n_estimators': 8},
 mean: 0.99702, std: 0.00186, params: {'learning_rate': 0.7, 'n_estimators': 9},
 mean: 0.95576, std: 0.00954, params: {'learning_rate': 0.8, 'n_estimators': 1},
 mean: 0.97227, std: 0.00856, params: {'learning_rate': 0.8, 'n_estimators': 2},
 mean: 0.98075, std: 0.00708, params: {'learning_rate': 0.8, 'n_estimators': 3},
 mean: 0.98579, std: 0.00674, params: {'learning_rate': 0.8, 'n_estimators': 4},
 mean: 0.99404, std: 0.00577, params: {'learning_rate': 0.8, 'n_estimators': 5},
 mean: 0.99794, std: 0.00255, params: {'learning_rate': 0.8, 'n_estimators': 6},
 mean: 0.99885, std: 0.00102, params: {'learning_rate': 0.8, 'n_estimators': 7},
 mean: 0.99908, std: 0.00086, params: {'learning_rate': 0.8, 'n_estimators': 8},
 mean: 0.99862, std: 0.00134, params: {'learning_rate': 0.8, 'n_estimators': 9},
 mean: 0.95576, std: 0.00954, params: {'learning_rate': 0.9, 'n_estimators': 1},
 mean: 0.97937, std: 0.00397, params: {'learning_rate': 0.9, 'n_estimators': 2},
 mean: 0.98900, std: 0.00809, params: {'learning_rate': 0.9, 'n_estimators': 3},
 mean: 0.98487, std: 0.00689, params: {'learning_rate': 0.9, 'n_estimators': 4},
 mean: 0.99496, std: 0.00438, params: {'learning_rate': 0.9, 'n_estimators': 5},
 mean: 0.99565, std: 0.00302, params: {'learning_rate': 0.9, 'n_estimators': 6},
 mean: 0.99931, std: 0.00056, params: {'learning_rate': 0.9, 'n_estimators': 7},
 mean: 0.99817, std: 0.00200, params: {'learning_rate': 0.9, 'n_estimators': 8},
 mean: 0.99817, std: 0.00200, params: {'learning_rate': 0.9, 'n_estimators': 9},
 mean: 0.95576, std: 0.00954, params: {'learning_rate': 1.0, 'n_estimators': 1},
 mean: 0.97937, std: 0.00397, params: {'learning_rate': 1.0, 'n_estimators': 2},
 mean: 0.98969, std: 0.00763, params: {'learning_rate': 1.0, 'n_estimators': 3},
 mean: 0.98648, std: 0.00519, params: {'learning_rate': 1.0, 'n_estimators': 4},
 mean: 0.99450, std: 0.00197, params: {'learning_rate': 1.0, 'n_estimators': 5},
 mean: 0.99633, std: 0.00152, params: {'learning_rate': 1.0, 'n_estimators': 6},
 mean: 0.99817, std: 0.00200, params: {'learning_rate': 1.0, 'n_estimators': 7},
 mean: 0.99931, std: 0.00056, params: {'learning_rate': 1.0, 'n_estimators': 8},
 mean: 0.99931, std: 0.00056, params: {'learning_rate': 1.0, 'n_estimators': 9}]
clf.best_score_
0.9993123997249599
clf.best_params_
{'learning_rate': 0.9, 'n_estimators': 7}
pred_val = clf.predict(X_validation)
print("Validation Accuracy: %.2f" % accuracy_score(y_validation, [round(x) for x in pred_val]))
Validation Accuracy: 1.00
pred_test = clf.predict(X_test)
print("Test Accuracy: %.2f" % accuracy_score(y_test, [round(x) for x in pred_test]))
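Grid search simply tries every combination in the grid; here that is 9 × 10 = 90 combinations, each scored by 5-fold cross validation. A minimal sketch of the search loop only, assuming a hypothetical `toy_score` in place of cross-validated accuracy; GridSearchCV additionally runs the CV folds for each combination and, with refit=True, refits the best model on all the training data:

```python
from itertools import product

def grid_search(param_grid, score_fn):
    """Evaluate every combination in param_grid and keep the highest-scoring one."""
    names = sorted(param_grid)
    best_params, best_score = None, float("-inf")
    for values in product(*(param_grid[name] for name in names)):
        params = dict(zip(names, values))
        score = score_fn(params)
        if score > best_score:        # strict ">" keeps the first combination on ties
            best_params, best_score = params, score
    return best_params, best_score

param_search = {
    "n_estimators": list(range(1, 10)),
    "learning_rate": [x / 10 for x in range(1, 11)],
}

def toy_score(params):
    """Hypothetical stand-in for mean CV accuracy; it peaks at learning_rate=0.5, 7 trees."""
    return -abs(params["learning_rate"] - 0.5) - abs(params["n_estimators"] - 7) / 10

best_params, best_score = grid_search(param_search, toy_score)
print(best_params)   # {'learning_rate': 0.5, 'n_estimators': 7}
```

The cost grows multiplicatively with every parameter added to the grid, which is why exhaustive search is usually limited to a few parameters at a time.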