Model Training and Prediction
1. Import the required modules
2. Preprocess the data: handle missing values, normalize features, encode categorical features, etc.
3. Train the model: choose a suitable machine-learning model and fit it on the training set
4. Predict the results
Random Forest
Random forest parameter reference: https://scikit-learn.org/stable/modules/generated/sklearn.ensemble.RandomForestClassifier.html#sklearn.ensemble.RandomForestClassifier
Random forest is an ensemble method. It achieves high accuracy, runs efficiently on large datasets, handles inputs with high-dimensional features, and can estimate feature importance.
Prediction with scikit-learn's random forest classifier:
from sklearn import datasets
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score
#Load the data
iris = datasets.load_iris()
feature = iris.feature_names
X = iris.data
y = iris.target
#Random forest
clf = RandomForestClassifier(n_estimators = 200)
train_X, test_X, train_y, test_y = train_test_split(X, y, test_size = 0.1, random_state = 5)
clf.fit(train_X, train_y)
test_pred = clf.predict(test_X)
#Inspect feature importances
print(str(feature) + '\n' + str(clf.feature_importances_))
#Evaluate the model with the F1 score
score = f1_score(test_y, test_pred, average = 'macro')
print("Random forest macro F1:", score)
score = f1_score(test_y, test_pred, average = 'weighted')
print("Random forest weighted F1:", score)
['sepal length (cm)', 'sepal width (cm)', 'petal length (cm)', 'petal width (cm)']
[0.10214829 0.02506777 0.42989154 0.4428924 ]
Random forest macro F1: 0.818181818181818
Random forest weighted F1: 0.8
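Random forests also support out-of-bag (OOB) evaluation, which estimates generalization without a held-out split: each tree is scored on the samples left out of its bootstrap draw. A minimal sketch (not in the original) on the same iris data:

```python
from sklearn import datasets
from sklearn.ensemble import RandomForestClassifier

iris = datasets.load_iris()
# oob_score=True scores each sample using only the trees that never saw it
clf = RandomForestClassifier(n_estimators=200, oob_score=True, random_state=5)
clf.fit(iris.data, iris.target)
print("OOB accuracy:", clf.oob_score_)
```

This is often a convenient sanity check alongside the F1 scores above.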
LightGBM Model
LightGBM learning resource: https://mp.weixin.qq.com/s/64xfT9WIgF3yEExpSxyshQ
LightGBM (Light Gradient Boosting Machine) is a framework implementing the GBDT algorithm. It supports efficient parallel training, and offers faster training speed, lower memory consumption, better accuracy, and distributed processing of massive data.
#Lightgbm
import lightgbm as lgb
from sklearn import datasets
from sklearn.model_selection import train_test_split
import numpy as np
from sklearn.metrics import roc_auc_score, accuracy_score, f1_score
import matplotlib.pyplot as plt
#Load the data
iris = datasets.load_iris()
X_train, X_test, y_train, y_test = train_test_split(iris.data, iris.target, test_size = 0.3)
#Convert to the Dataset format
train_data = lgb.Dataset(X_train, label = y_train)
validation_data = lgb.Dataset(X_test, label = y_test)
#Parameters
results = {}
params = {'learning_rate': 0.1, 'lambda_l1': 0.1, 'lambda_l2': 0.9, 'max_depth': 1, 'objective': 'multiclass', 'num_class': 3, 'verbose': -1}
#Train the model
gbm = lgb.train(params, train_data, valid_sets = (validation_data, train_data), valid_names = ('validate','train'), evals_result = results)
#Predict
y_pred_test = gbm.predict(X_test)
y_pred_data = gbm.predict(X_train)
y_pred_data = [list(x).index(max(x)) for x in y_pred_data]
y_pred_test = [list(x).index(max(x)) for x in y_pred_test]
#Evaluate the model
print(accuracy_score(y_test, y_pred_test))
print('Train set', f1_score(y_train, y_pred_data, average = 'macro'))
print('Validation set', f1_score(y_test, y_pred_test, average = 'macro'))
lgb.plot_metric(results)
plt.show()
The output is as follows:
[1] train's multi_logloss: 0.976938 validate's multi_logloss: 0.981023
[2] train's multi_logloss: 0.875592 validate's multi_logloss: 0.88226
[3] train's multi_logloss: 0.788956 validate's multi_logloss: 0.802856
...
[33] train's multi_logloss: 0.103932 validate's multi_logloss: 0.239186
[34] train's multi_logloss: 0.099617 validate's multi_logloss: 0.236112
[35] train's multi_logloss: 0.0954041 validate's multi_logloss: 0.236957
...
[100] train's multi_logloss: 0.0238036 validate's multi_logloss: 0.318235
(iterations abridged; the validation loss bottoms out around iteration 34, then climbs while the training loss keeps falling)
0.9333333333333333
Train set 1.0
Validation set 0.9332591768631814
[1] train's multi_logloss: 0.981573 validate's multi_logloss: 0.985608
[2] train's multi_logloss: 0.883927 validate's multi_logloss: 0.89032
...
[46] train's multi_logloss: 0.09341 validate's multi_logloss: 0.239593
[47] train's multi_logloss: 0.0915987 validate's multi_logloss: 0.238425
[48] train's multi_logloss: 0.0899355 validate's multi_logloss: 0.238867
...
[100] train's multi_logloss: 0.0593146 validate's multi_logloss: 0.249275
(iterations abridged; the validation loss reaches its minimum around iteration 47)
0.9333333333333333
Train set 1.0
Validation set 0.9332591768631814
The validation loss is larger than the training loss, so the model is overfitting.
Ways to handle overfitting: apply regularization via lambda_l1, lambda_l2 and min_gain_to_split; use a smaller max_bin; use a smaller num_leaves; etc.
Here we raise lambda_l1 to 0.9:
params = {'learning_rate': 0.1, 'lambda_l1': 0.9, 'lambda_l2': 0.9, 'max_depth': 1, 'objective': 'multiclass', 'num_class': 3, 'verbose': -1}
#Train the model
gbm = lgb.train(params, train_data, valid_sets = (validation_data, train_data), valid_names = ('validate','train'), evals_result = results)
#Predict
y_pred_test = gbm.predict(X_test)
y_pred_data = gbm.predict(X_train)
y_pred_data = [list(x).index(max(x)) for x in y_pred_data]
y_pred_test = [list(x).index(max(x)) for x in y_pred_test]
#Evaluate the model
print(accuracy_score(y_test, y_pred_test))
print('Train set', f1_score(y_train, y_pred_data, average = 'macro'))
print('Validation set', f1_score(y_test, y_pred_test, average = 'macro'))
lgb.plot_metric(results)
#Plot feature importance
lgb.plot_importance(gbm, importance_type = 'split' )
plt.show()
XGBoost Model
XGBoost learning resource: https://mp.weixin.qq.com/s/AAKPSIHk1iUqCeUibrORqQ
#xgboost
from sklearn.datasets import load_iris
import xgboost as xgb
from xgboost import plot_importance
from matplotlib import pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score  # F1 score
# Load the sample dataset
iris = load_iris()
X, y = iris.data, iris.target
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=1234565) # split the dataset
# Algorithm parameters
params = {
'booster': 'gbtree',
'objective': 'multi:softmax',
'eval_metric':'mlogloss',
'num_class': 3,
'gamma': 0.1,
'max_depth': 6,
'lambda': 2,
'subsample': 0.7,
'colsample_bytree': 0.75,
'min_child_weight': 3,
'eta': 0.1,
'seed': 1,
'nthread': 4,
}
train_data = xgb.DMatrix(X_train, y_train) # build the DMatrix
num_rounds = 500
model = xgb.train(params, train_data, num_boost_round = num_rounds) # train the XGBoost model
# Predict on the test set
dtest = xgb.DMatrix(X_test)
y_pred = model.predict(dtest)
# Compute the F1 score
F1_score = f1_score(y_test,y_pred,average='macro')
print("F1_score: %.2f%%" % (F1_score*100.0))
# Plot feature importance
plot_importance(model)
plt.show()
F1_score: 95.56%
Cross-Validation
Cross-validation splits the original data into a training part and a validation part. The main variants are simple (hold-out) cross-validation, k-fold cross-validation, leave-one-out cross-validation, and leave-P-out cross-validation. The code is as follows:
#Cross-validation
#Simple (hold-out) cross-validation
from sklearn.model_selection import train_test_split
from sklearn import datasets
#Load the dataset
iris = datasets.load_iris()
feature = iris.feature_names
X = iris.data
y = iris.target
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size = 0.4, random_state = 0)
#K-fold cross-validation
from sklearn.model_selection import KFold
folds = KFold(n_splits = 10, shuffle = True)
#Leave-one-out cross-validation
from sklearn.model_selection import LeaveOneOut
loo = LeaveOneOut()
#Leave-P-out cross-validation
from sklearn.model_selection import LeavePOut
lpo = LeavePOut(p = 5)
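The splitters above plug directly into cross_val_score, which trains and scores a model on each fold; a minimal sketch (not in the original) using a random forest on the same iris data:

```python
from sklearn import datasets
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import KFold, cross_val_score

iris = datasets.load_iris()
clf = RandomForestClassifier(n_estimators=100, random_state=0)
folds = KFold(n_splits=10, shuffle=True, random_state=0)
# One macro-F1 score per fold
scores = cross_val_score(clf, iris.data, iris.target, cv=folds, scoring='f1_macro')
print(scores.mean())
```

The same pattern works with LeaveOneOut or LeavePOut passed as cv.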
Parameter Tuning
1. Grid search: exhaustive search over a predefined parameter grid;
2. Learning curve: plot the model's accuracy on the training set and in cross-validation as the training-set size grows, to see how the model performs on new data and to judge its variance.
Grid search is exhaustive and therefore slow (the grid below has 5^10 ≈ 9.8 million parameter combinations), so the code below was not run to completion.
#Grid-search example using XGBoost
import xgboost as xgb
from sklearn.model_selection import train_test_split
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import GridSearchCV
cancer = load_breast_cancer()
x = cancer.data[:50]
y = cancer.target[:50]
train_x, valid_x, train_y, valid_y = train_test_split(x, y, test_size=0.333, random_state=0) # split the data
# A DMatrix is not needed here
parameters = {
'max_depth': [5, 10, 15, 20, 25],
'learning_rate': [0.01, 0.02, 0.05, 0.1, 0.15],
'n_estimators': [50, 100, 200, 300, 500],
'min_child_weight': [0, 2, 5, 10, 20],
'max_delta_step': [0, 0.2, 0.6, 1, 2],
'subsample': [0.6, 0.7, 0.8, 0.85, 0.95],
'colsample_bytree': [0.5, 0.6, 0.7, 0.8, 0.9],
'reg_alpha': [0, 0.25, 0.5, 0.75, 1],
'reg_lambda': [0.2, 0.4, 0.6, 0.8, 1],
'scale_pos_weight': [0.2, 0.4, 0.6, 0.8, 1]
}
xlf = xgb.XGBClassifier(max_depth=10,
learning_rate=0.01,
n_estimators=2000,
silent=True,
objective='binary:logistic',
nthread=-1,
gamma=0,
min_child_weight=1,
max_delta_step=0,
subsample=0.85,
colsample_bytree=0.7,
colsample_bylevel=1,
reg_alpha=0,
reg_lambda=1,
scale_pos_weight=1,
seed=1440,
missing=None)
# GridSearchCV handles the training loops internally; just call fit once
gsearch = GridSearchCV(xlf, param_grid=parameters, scoring='accuracy', cv=3)
gsearch.fit(train_x, train_y)
print("Best score: %0.3f" % gsearch.best_score_)
print("Best parameters set:")
best_parameters = gsearch.best_estimator_.get_params()
for param_name in sorted(parameters.keys()):
    print("\t%s: %r" % (param_name, best_parameters[param_name]))
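The learning curve mentioned above can be drawn with sklearn's learning_curve, which trains the model on growing subsets of the data. A minimal sketch (shuffle=True avoids class-ordered subsets on iris):

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import learning_curve

iris = load_iris()
clf = RandomForestClassifier(n_estimators=50, random_state=0)
# Train on 30%..100% of each training fold, scoring with 5-fold CV
sizes, train_scores, valid_scores = learning_curve(
    clf, iris.data, iris.target,
    train_sizes=np.linspace(0.3, 1.0, 5), cv=5,
    scoring='accuracy', shuffle=True, random_state=0)
for n, t, v in zip(sizes, train_scores.mean(axis=1), valid_scores.mean(axis=1)):
    print(n, round(t, 3), round(v, 3))
```

A persistent gap between the training and validation curves indicates high variance (overfitting); curves that converge at a low score indicate high bias.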
Smart Ocean (智慧海洋) Competition Code
Note on imports: use import config instead of from config import config.
import config
import pandas as pd
import numpy as np
from tqdm import tqdm
from sklearn.metrics import classification_report, f1_score
from sklearn.model_selection import StratifiedKFold, KFold, train_test_split
import lightgbm as lgb
import os
import warnings
from hyperopt import fmin, tpe, hp, STATUS_OK, Trials
all_df = pd.read_csv(r"C:\Users\李\Desktop\datawheal\data\group_df.csv", index_col = 0)
use_train = all_df[all_df['label'] != -1]
use_test = all_df[all_df['label'] == -1] # rows with label == -1 are the test set
use_feats = [c for c in use_train.columns if c not in ['ID', 'label']]
X_train, X_verify, y_train, y_verify = train_test_split(use_train[use_feats], use_train['label'], test_size = 0.3, random_state = 0)
Feature selection by importance:
#Feature selection by importance
selectFeatures = 200 # number of features to keep
earlyStopping = 100 # early-stopping rounds
select_num_boost_round = 1000 # training rounds for feature selection
#Base parameters
selfParam = { 'learning_rate':0.01, # learning rate
    'boosting':'dart', # boosting type: gbdt or dart
    'objective':'multiclass', # multiclass classification
    'metric':'None',
    'num_leaves':32, # number of leaves
    'feature_fraction':0.7, # fraction of features used per tree
    'bagging_fraction':0.8, # fraction of samples used per tree
    'min_data_in_leaf':30, # minimum samples per leaf
    'num_class': 3,
    'max_depth':6, # maximum tree depth
    'num_threads':8, # number of LightGBM threads
    'min_data_in_bin':30, # minimum samples per bin
    'max_bin':256, # maximum number of bins
    'is_unbalance':True, # unbalanced classes
    'train_metric':True,
    'verbose':-1,
}
The feature-selection results are as follows:
select forward 200 features: [('pos_neq_zero_speed_q_40', 1783),
('lat_lon_countvec_1_x', 1771), ('rank2_mode_lat', 1737),
('pos_neq_zero_speed_median', 1379), ('pos_neq_zero_speed_q_60', 1369),
('lat_lon_tfidf_0_x', 1251), ('pos_neq_zero_speed_q_80', 1194),
('sample_tfidf_0_x', 1168), ('w2v_9_mean', 1134), ('lat_lon_tfidf_11_x', 963),
('rank3_mode_lat', 946), ('w2v_5_mean', 900), ('w2v_16_mean', 874),
...,
('lat_lon_countvec_29_x', 136), ('lat_lon_tfidf_29_y', 136), ('grad_tfidf_5_y', 136)]
Hyperparameter tuning with Bayesian optimization:
#model_feature is the list of selected feature names
model_feature = [k[0] for k in sort_feature_importance[:selectFeatures]]
#Hyperparameter optimization
#Search space
spaceParam = {'boosting': hp.choice('boosting',['gbdt','dart']),
'learning_rate':hp.loguniform('learning_rate', np.log(0.01), np.log(0.05)),
'num_leaves': hp.quniform('num_leaves', 3, 66, 3),
'feature_fraction': hp.uniform('feature_fraction', 0.7,1),
'min_data_in_leaf': hp.quniform('min_data_in_leaf', 10, 50,5),
'num_boost_round':hp.quniform('num_boost_round',500,2000,100),
'bagging_fraction':hp.uniform('bagging_fraction',0.6,1)
}
#Decode the sampled / best parameters
def getParam(param):
    for k in ['num_leaves', 'min_data_in_leaf', 'num_boost_round']:
        param[k] = int(float(param[k]))
    for k in ['learning_rate', 'feature_fraction', 'bagging_fraction']:
        param[k] = float(param[k])
    if param['boosting'] == 0:
        param['boosting'] = 'gbdt'
    elif param['boosting'] == 1:
        param['boosting'] = 'dart'
    #Add fixed parameters
    param['objective'] = 'multiclass'
    param['max_depth'] = 7
    param['num_threads'] = 8
    param['is_unbalance'] = True
    param['metric'] = 'None'
    param['train_metric'] = True
    param['verbose'] = -1
    param['bagging_freq'] = 5
    param['num_class'] = 3
    param['feature_pre_filter'] = False
    return param
#Custom F1 eval metric and the hyperopt objective
def f1_score_eval(preds, valid_df):
    labels = valid_df.get_label()
    preds = np.argmax(preds.reshape(3, -1), axis = 0)
    scores = f1_score(y_true = labels, y_pred = preds, average = 'macro')
    return 'f1_score', scores, True

def lossFun(param):
    param = getParam(param)
    m = lgb.train(param, train_set = train_data, num_boost_round = param['num_boost_round'],
                  valid_sets = [train_data, valid_data], valid_names = ['train', 'valid'],
                  feature_name = features, feval = f1_score_eval,
                  early_stopping_rounds = earlyStopping, verbose_eval = False,
                  keep_training_booster = True)
    train_f1_score = m.best_score['train']['f1_score']
    valid_f1_score = m.best_score['valid']['f1_score']
    loss_f1_score = 1 - valid_f1_score
    print('train f1_score: {}, valid f1_score: {}, loss_f1_score: {}'.format(train_f1_score, valid_f1_score, loss_f1_score))
    return {'loss': loss_f1_score, 'params': param, 'status': STATUS_OK}
features = model_feature
train_data = lgb.Dataset(data = X_train[model_feature], label = y_train, feature_name = features)
valid_data = lgb.Dataset(data = X_verify[features], label = y_verify, feature_name = features, reference=train_data)
#Search for the best parameters
best_param = fmin(fn = lossFun, space = spaceParam, algo = tpe.suggest, max_evals = 100, trials = Trials())
best_param = getParam(best_param)
print('Search best param:', best_param)
E:\anacoda\lib\site-packages\lightgbm\callback.py:186: UserWarning: Early stopping is not available in dart mode
  warnings.warn('Early stopping is not available in dart mode')
train f1_score: 1.0, test f1_score: 0.9285309559097713, loss_f1_score: 0.0714690440902287
train f1_score: 0.9963233527198608, test f1_score: 0.9095083779923133, loss_f1_score: 0.09049162200768668
train f1_score: 1.0, test f1_score: 0.9257391891204496, loss_f1_score: 0.07426081087955039
... (97 further trials omitted; the best trial reaches test f1_score 0.9305299545426143, i.e. loss 0.06947004545738567) ...
100%|██████████| 100/100 [59:33<00:00, 35.74s/trial, best loss: 0.06947004545738567]
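For `hp.choice` entries, `fmin` returns the index of the chosen option rather than the option itself, which is why the raw result is passed through `getParam` before printing. The sketch below is an illustrative stand-in for that conversion; the option lists in `CHOICES` are assumptions for demonstration, not the actual search space used above:

```python
# fmin returns indices for hp.choice parameters, so they must be mapped
# back to the concrete values before training a final model.
# (Illustrative stand-in for the getParam helper defined earlier;
# the candidate lists here are hypothetical.)
CHOICES = {
    'boosting_type': ['gbdt', 'dart'],
    'num_leaves': list(range(20, 100, 10)),
}

def decode_param(raw):
    out = dict(raw)
    for key, options in CHOICES.items():
        if key in out:
            out[key] = options[int(out[key])]   # index -> value
    return out

best_raw = {'boosting_type': 0, 'num_leaves': 4, 'learning_rate': 0.05}
print(decode_param(best_raw))
# → {'boosting_type': 'gbdt', 'num_leaves': 60, 'learning_rate': 0.05}
```

Continuous parameters such as `learning_rate` already come back as values and pass through unchanged.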
Five-fold cross-validation:
#5-fold cross-validation
def f1_score_eval(preds, valid_df):
    labels = valid_df.get_label()
    preds = np.argmax(preds.reshape(3, -1), axis = 0)
    scores = f1_score(y_true = labels, y_pred = preds, average = 'macro')
    return 'f1_score', scores, True
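This custom `feval` assumes the raw multiclass scores arrive as a flat, class-major array (all class-0 scores first, then class 1, then class 2), so `reshape(3, -1)` puts classes on rows and `argmax(axis=0)` recovers one label per sample. (Other LightGBM versions may hand `feval` a 2-D `(n_samples, n_classes)` array instead, in which case `argmax(axis=1)` on it directly is enough.) The decoding step can be checked in isolation:

```python
import numpy as np

# Flat class-major scores for 4 samples and 3 classes:
# after reshape(3, -1), rows are classes and columns are samples.
flat = np.array([0.7, 0.1, 0.2, 0.1,   # class 0 scores
                 0.2, 0.8, 0.3, 0.1,   # class 1 scores
                 0.1, 0.1, 0.5, 0.8])  # class 2 scores
labels = np.argmax(flat.reshape(3, -1), axis=0)
print(labels)  # → [0 1 2 2]
```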
def sub_on_line_lgb(train_, test_, pred, label, cate_cols, split, is_shuffle = True, use_cart = False, get_prob = False):
    n_class = 3
    train_pred = np.zeros((train_.shape[0], n_class))
    test_pred = np.zeros((test_.shape[0], n_class))
    n_splits = 5
    assert split in ['kf', 'skf'], '{} Not Support this type of split way'.format(split)
    if split == 'kf':
        folds = KFold(n_splits = n_splits, shuffle = is_shuffle, random_state = 1024)
        kf_way = folds.split(train_[pred])
    else:
        folds = StratifiedKFold(n_splits = n_splits, shuffle = is_shuffle, random_state = 1024)
        kf_way = folds.split(train_[pred], train_[label])
    print('Use {} features ...'.format(len(pred)))
    #Parameters updated to the Bayesian-optimised values
    params = {'learning_rate': 0.05,
              'boosting_type': 'gbdt',
              'objective': 'multiclass',
              'metric': 'None',
              'num_leaves': 60,
              'feature_fraction': 0.86,
              'bagging_fraction': 0.73,
              'bagging_freq': 5,
              'seed': 1,
              'bagging_seed': 1,
              'feature_fraction_seed': 7,
              'min_data_in_leaf': 15,
              'num_class': n_class,
              'nthread': 8,
              'verbose': -1,
              'num_boost_round': 1100,
              'max_depth': 7,}
    for n_fold, (train_idx, valid_idx) in enumerate(kf_way, start = 1):
        print('the {} training start ...'.format(n_fold))
        train_x, train_y = train_[pred].iloc[train_idx], train_[label].iloc[train_idx]
        valid_x, valid_y = train_[pred].iloc[valid_idx], train_[label].iloc[valid_idx]
        if use_cart:
            dtrain = lgb.Dataset(train_x, label = train_y, categorical_feature = cate_cols)
            dvalid = lgb.Dataset(valid_x, label = valid_y, categorical_feature = cate_cols)
        else:
            dtrain = lgb.Dataset(train_x, label = train_y)
            dvalid = lgb.Dataset(valid_x, label = valid_y)
        clf = lgb.train(params = params, train_set = dtrain, valid_sets = [dvalid], early_stopping_rounds = 100, verbose_eval = 100, feval = f1_score_eval)
        train_pred[valid_idx] = clf.predict(valid_x, num_iteration = clf.best_iteration)
        test_pred += clf.predict(test_[pred], num_iteration = clf.best_iteration) / folds.n_splits
        #train_pred is only partially filled until the last fold, so early reports look poor
        print(classification_report(train_[label], np.argmax(train_pred, axis = 1), digits = 4))
    if get_prob:
        sub_probs = ['qyxs_prob_{}'.format(q) for q in ['围网', '刺网', '拖网']]
        prob_df = pd.DataFrame(test_pred, columns = sub_probs)
        prob_df['ID'] = test_['ID'].values
        return prob_df
    else:
        test_['label'] = np.argmax(test_pred, axis = 1)
        return test_[['ID', 'label']]
use_train = all_df[all_df['label'] != -1]
use_test = all_df[all_df['label'] == -1]
# use_feats = [c for c in use_train.columns if c not in ['ID', 'label']]
use_feats = model_feature
sub = sub_on_line_lgb(use_train, use_test, use_feats, 'label', [], 'kf', is_shuffle = True, use_cart = False, get_prob = False)
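One thing to keep in mind when reading the log below: `classification_report` is called inside the fold loop on the whole training set, but `train_pred` only holds out-of-fold predictions for the folds finished so far. The untouched rows are still all zeros, and `np.argmax` maps an all-zero row to class 0, so the first folds report very low accuracy and only the fifth report reflects true out-of-fold performance. A minimal numpy sketch of this accumulation effect:

```python
import numpy as np

n_samples, n_class = 6, 3
train_pred = np.zeros((n_samples, n_class))

# Simulate two of three folds having been predicted already.
fold_indices = [np.array([0, 1]), np.array([2, 3]), np.array([4, 5])]
for idx in fold_indices[:2]:
    fake_probs = np.full((len(idx), n_class), 1.0 / n_class)
    fake_probs[:, 2] = 0.9          # pretend class 2 wins on these rows
    train_pred[idx] = fake_probs

labels = np.argmax(train_pred, axis=1)
print(labels)  # → [2 2 2 2 0 0]  (unpredicted rows fall back to class 0)
```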
The output is as follows:
Use 200 features ...
the 1 training start ...
Training until validation scores don't improve for 100 rounds
[100] valid_0's f1_score: 0.894256
[200] valid_0's f1_score: 0.909942
[300] valid_0's f1_score: 0.913423
[400] valid_0's f1_score: 0.917897
[500] valid_0's f1_score: 0.920616
Early stopping, best iteration is:
[456] valid_0's f1_score: 0.920717
              precision    recall  f1-score   support

           0     0.2663    0.9772    0.4185      1621
           1     0.9672    0.1739    0.2948      1018
           2     0.9539    0.1899    0.3167      4361

    accuracy                         0.3699      7000
   macro avg     0.7291    0.4470    0.3433      7000
weighted avg     0.7966    0.3699    0.3371      7000
the 2 training start ...
Training until validation scores don't improve for 100 rounds
[100] valid_0's f1_score: 0.918357
[200] valid_0's f1_score: 0.916436
Early stopping, best iteration is:
[140] valid_0's f1_score: 0.92449
              precision    recall  f1-score   support

           0     0.3169    0.9562    0.4760      1621
           1     0.9531    0.3595    0.5221      1018
           2     0.9548    0.3777    0.5412      4361

    accuracy                         0.5090      7000
   macro avg     0.7416    0.5645    0.5131      7000
weighted avg     0.8068    0.5090    0.5234      7000
the 3 training start ...
Training until validation scores don't improve for 100 rounds
[100] valid_0's f1_score: 0.915242
[200] valid_0's f1_score: 0.927189
[300] valid_0's f1_score: 0.930614
Early stopping, best iteration is:
[238] valid_0's f1_score: 0.930614
              precision    recall  f1-score   support

           0     0.3946    0.9389    0.5557      1621
           1     0.9571    0.5255    0.6785      1018
           2     0.9574    0.5673    0.7125      4361

    accuracy                         0.6473      7000
   macro avg     0.7697    0.6773    0.6489      7000
weighted avg     0.8270    0.6473    0.6712      7000
the 4 training start ...
Training until validation scores don't improve for 100 rounds
[100] valid_0's f1_score: 0.901683
[200] valid_0's f1_score: 0.912985
[300] valid_0's f1_score: 0.916988
[400] valid_0's f1_score: 0.92147
[500] valid_0's f1_score: 0.921353
Early stopping, best iteration is:
[411] valid_0's f1_score: 0.922153
              precision    recall  f1-score   support

           0     0.5392    0.9173    0.6792      1621
           1     0.9589    0.7112    0.8167      1018
           2     0.9555    0.7640    0.8491      4361

    accuracy                         0.7919      7000
   macro avg     0.8179    0.7975    0.7817      7000
weighted avg     0.8596    0.7919    0.8051      7000
the 5 training start ...
Training until validation scores don't improve for 100 rounds
[100] valid_0's f1_score: 0.900975
[200] valid_0's f1_score: 0.908373
[300] valid_0's f1_score: 0.91384
[400] valid_0's f1_score: 0.917567
Early stopping, best iteration is:
[369] valid_0's f1_score: 0.919843
              precision    recall  f1-score   support

           0     0.8726    0.9001    0.8861      1621
           1     0.9569    0.8949    0.9249      1018
           2     0.9586    0.9619    0.9603      4361

    accuracy                         0.9379      7000
   macro avg     0.9294    0.9190    0.9238      7000
weighted avg     0.9385    0.9379    0.9380      7000
<ipython-input-4-ad6d0ae907b5>:58: SettingWithCopyWarning:
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead
See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
test_['label'] = np.argmax(test_pred, axis=1)
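The `SettingWithCopyWarning` is raised because `use_test` is a boolean-mask slice of `all_df`, so `test_['label'] = ...` assigns into what may be a view of the original frame. A common fix, shown here on toy data with the same column names, is to take an explicit `.copy()` before assigning:

```python
import numpy as np
import pandas as pd

all_df = pd.DataFrame({'ID': [1, 2, 3], 'label': [0, -1, -1]})
test_pred = np.array([[0.1, 0.9, 0.0],
                      [0.2, 0.1, 0.7]])

# .copy() makes test_ an independent frame, so the assignment below no
# longer targets a slice of all_df and the warning disappears.
test_ = all_df[all_df['label'] == -1].copy()
test_['label'] = np.argmax(test_pred, axis=1)
print(test_[['ID', 'label']])
```

The original `all_df` is left untouched, which is exactly what the submission code wants here.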