MNIST2_LGB_XGB训练预测

针对MNIST数据集进行XGB\LGB模型训练和预测
部分脚本如下: 完整脚本见笔者github

lgb_param = {
    'boosting': 'gbdt',
    'num_iterations': 145,
    'num_threads' : 8, 
    'verbosity': 0,
    'learning_rate': 0.2,
    'max_depth' : 10,
    'num_leaves' : 8,
    'subsample' : 0.75,
    'subsample_freq': 5,
    'colsample_bytree' : 1,
    'reg_alpha': 1.5,
    'reg_lambda': 0.75,
    'objective': 'multiclass',
    'num_class': 10,
    'metric': 'multi_logloss',
    'early_stopping': 25
    # 'device': 'gpu',
    # 'gpu_platform_id': 0,
    # 'gpu_device_id': 0
}

xgb_param = {
    'booster': 'gbtree',
    'tree_method':'gpu_hist',
    'num_rounds': 160,
    'nthread' : 8, 
    'silent' : 1,
    'learning_rate': 0.2,
    'max_depth' : 10,
    'num_leaves' : 8,
    'subsample' : 0.75,
    'colsample_bytree' : 1,
    'reg_alpha': 1.5,
    'reg_lambda': 0.75,
    'objective': 'multi:softprob',
    'num_class': 10,
    'metric': 'mlogloss',
    'early_stopping': 25
}


@clock
def lgb_xgb_train(model, param, tr, te ):
    if model.__name__ == 'lightgbm':
        trdt = model.Dataset(data=tr.iloc[:, :-1].values, label=tr.iloc[:, -1].values)
        tedt = model.Dataset(data=te.iloc[:, :-1].values, label=te.iloc[:, -1].values)
        clf_model = model.train(param, trdt, valid_sets=[trdt, tedt] ,verbose_eval = 20)
        pred = np.argmax(clf_model.predict(te.iloc[:, :-1].values, num_iteration=clf_model.best_iteration ), axis=1)

    else:
        trdt = model.DMatrix(data=tr.iloc[:, :-1].values, label=tr.iloc[:, -1].values)
        tedt = model.DMatrix(data=te.iloc[:, :-1].values, label=te.iloc[:, -1].values)
        clf_model = model.train(param, trdt, evals=[(trdt, 'train'), (tedt, 'test')], verbose_eval = 20)
        pred = np.argmax(clf_model.predict(tedt, ntree_limit=-1), axis=1)
    
    y_te =  te.iloc[:, -1].values
    acc_ = sum(pred == y_te)/len(y_te) * 100
    return f'model: {model.__name__}, acc: {acc_:.2f}'

if __name__ == '__main__':
    mnistdf = get_ministdata()
    te_index = mnistdf.sample(frac=0.8).index.tolist()
    mnist_te = mnistdf.loc[te_index, :]
    mnist_tr = mnistdf.loc[~mnistdf.index.isin(te_index), :]
    print('train xgb ...')
    resxgb = lgb_xgb_train(xgb, xgb_param, mnist_tr, mnist_te)
    print('train lgb ...')
    reslgb = lgb_xgb_train(lgb, lgb_param, mnist_tr, mnist_te)

  • 结果如下
train xgb ...
[0]     train-merror:0.078143   test-merror:0.144911
[9]     train-merror:0.013643   test-merror:0.070464
lgb_xgb_train, take_time:37.27306s >> model: xgboost, acc: 92.95
train lgb ...
Training until validation scores don't improve for 25 rounds.
[20]    training's multi_logloss: 0.360737      valid_1's multi_logloss: 0.419623
[40]    training's multi_logloss: 0.178201      valid_1's multi_logloss: 0.259657
[60]    training's multi_logloss: 0.110021      valid_1's multi_logloss: 0.206787
[80]    training's multi_logloss: 0.0729135     valid_1's multi_logloss: 0.180611
[100]   training's multi_logloss: 0.051499      valid_1's multi_logloss: 0.16564
[120]   training's multi_logloss: 0.0381409     valid_1's multi_logloss: 0.156427
[LightGBM] [Warning] No further splits with positive gain, best gain: -inf
[LightGBM] [Warning] No further splits with positive gain, best gain: -inf
[140]   training's multi_logloss: 0.0299155     valid_1's multi_logloss: 0.151268
[LightGBM] [Warning] No further splits with positive gain, best gain: -inf
[LightGBM] [Warning] No further splits with positive gain, best gain: -inf
[LightGBM] [Warning] No further splits with positive gain, best gain: -inf
Did not meet early stopping. Best iteration is:
[145]   training's multi_logloss: 0.0283099     valid_1's multi_logloss: 0.150082
lgb_xgb_train, take_time:27.31041s >> model: lightgbm, acc: 95.43
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Scc_hy

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值