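Task4---Model Building

Model training and prediction

1. Import the required modules
2. Preprocess the data: handle missing values, normalize features, convert categorical features, etc.
3. Train the model: choose a suitable machine learning model and fit it on the training set
4. Predict the results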
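The four steps above can be sketched with a scikit-learn `Pipeline`. The toy DataFrame and its columns (`speed`, `direction`, `port`, `label`) are hypothetical. Median imputation, standardization and one-hot encoding are one reasonable choice for each preprocessing step, not the only one.

```python
# Step 1: import the required modules
import numpy as np
import pandas as pd
from sklearn.compose import ColumnTransformer
from sklearn.ensemble import RandomForestClassifier
from sklearn.impute import SimpleImputer
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import OneHotEncoder, StandardScaler

# Hypothetical toy data: two numeric columns with missing values and one categorical column
df = pd.DataFrame({
    'speed': [1.0, 2.5, np.nan, 4.0, 0.5, 3.0],
    'direction': [10, 200, 90, np.nan, 30, 180],
    'port': ['A', 'B', 'A', 'C', 'B', 'A'],
    'label': [0, 1, 0, 1, 0, 1],
})
X, y = df.drop(columns='label'), df['label']

# Step 2: fill missing values and normalize the numeric columns, one-hot encode the categorical column
preprocess = ColumnTransformer([
    ('num', Pipeline([('impute', SimpleImputer(strategy='median')),
                      ('scale', StandardScaler())]), ['speed', 'direction']),
    ('cat', OneHotEncoder(handle_unknown='ignore'), ['port']),
])
# Step 3: chain preprocessing and a model, then fit on the training data
model = Pipeline([('prep', preprocess),
                  ('clf', RandomForestClassifier(n_estimators=50, random_state=0))])
model.fit(X, y)
# Step 4: predict; the fitted preprocessing is applied automatically
pred = model.predict(X)
print(pred)
```

This sketch predicts on the training rows only to stay short. In practice `model.predict` would be called on held-out data.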

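Random forest

Random forest parameters: https://scikit-learn.org/stable/modules/generated/sklearn.ensemble.RandomForestClassifier.html#sklearn.ensemble.RandomForestClassifier
Random forest is an ensemble learning method. It is usually accurate, runs efficiently on large datasets, handles high-dimensional input samples, and can estimate feature importance.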
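You can check the ensemble nature of a random forest directly in scikit-learn. A fitted `RandomForestClassifier` keeps its trees in `estimators_`, and its `predict_proba` is the average of the trees' `predict_proba`. The sketch below checks that on iris with 50 trees; these settings are illustrative choices. It also shows that the impurity-based `feature_importances_` are normalized to sum to 1, which is why the four importances printed further down add up to about 1.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier

X, y = load_iris(return_X_y=True)
clf = RandomForestClassifier(n_estimators=50, max_features='sqrt', random_state=0).fit(X, y)

# Each tree is trained on a bootstrap sample with a random feature subset at each split;
# the forest's class probabilities are the mean of the individual trees' probabilities
tree_avg = np.mean([tree.predict_proba(X) for tree in clf.estimators_], axis=0)
print(np.allclose(tree_avg, clf.predict_proba(X)))

# Impurity-based importances, normalized so they sum to 1
print(clf.feature_importances_, clf.feature_importances_.sum())
```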
Prediction with scikit-learn's random forest classifier:

```python
from sklearn import datasets
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score
# Load the data
iris = datasets.load_iris()
feature = iris.feature_names
X = iris.data
y = iris.target
# Random forest with 200 trees
clf = RandomForestClassifier(n_estimators = 200)
train_X, test_X, train_y, test_y = train_test_split(X, y, test_size = 0.1, random_state = 5)
clf.fit(train_X, train_y)
test_pred = clf.predict(test_X)
# Inspect the feature importances
print(str(feature) + '\n' + str(clf.feature_importances_))
# Evaluate the model with the F1 score
score = f1_score(test_y, test_pred, average = 'macro')
print("Random forest - macro:", score)

score = f1_score(test_y, test_pred, average = 'weighted')
print("Random forest - weighted:", score)
```

['sepal length (cm)', 'sepal width (cm)', 'petal length (cm)', 'petal width (cm)']
[0.10214829 0.02506777 0.42989154 0.4428924 ]
Random forest - macro: 0.818181818181818
Random forest - weighted: 0.8
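The two F1 scores above differ only in how the per-class F1 values are averaged. `macro` gives every class the same weight. `weighted` weights each class by its support, i.e. its number of true samples. Here is a hand computation on made-up labels, checked against `f1_score`:

```python
from collections import Counter
from sklearn.metrics import f1_score

def per_class_f1(y_true, y_pred, label):
    """F1 for one class: 2TP / (2TP + FP + FN)."""
    tp = sum(t == label and p == label for t, p in zip(y_true, y_pred))
    fp = sum(t != label and p == label for t, p in zip(y_true, y_pred))
    fn = sum(t == label and p != label for t, p in zip(y_true, y_pred))
    return 2 * tp / (2 * tp + fp + fn) if tp else 0.0

y_true = [0, 0, 0, 1, 1, 2]   # hypothetical labels
y_pred = [0, 0, 1, 1, 1, 2]
labels = sorted(set(y_true) | set(y_pred))
f1s = {c: per_class_f1(y_true, y_pred, c) for c in labels}   # {0: 0.8, 1: 0.8, 2: 1.0}
support = Counter(y_true)                                      # {0: 3, 1: 2, 2: 1}

macro = sum(f1s.values()) / len(labels)                            # plain mean over classes
weighted = sum(f1s[c] * support[c] for c in labels) / len(y_true)  # mean weighted by support
print(macro, f1_score(y_true, y_pred, average='macro'))
print(weighted, f1_score(y_true, y_pred, average='weighted'))
```

The rare class 2 is predicted perfectly, so it lifts the macro average (about 0.867) more than the weighted one (about 0.833).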

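LightGBM model

LightGBM learning material: https://mp.weixin.qq.com/s/64xfT9WIgF3yEExpSxyshQ
LightGBM (Light Gradient Boosting Machine) is a framework that implements the GBDT (gradient-boosted decision tree) algorithm. It supports efficient parallel training and offers faster training, lower memory consumption, better accuracy, and distributed training for processing very large datasets quickly.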
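Gradient boosting builds trees one at a time. Each new tree is fitted to the negative gradient of the loss at the current prediction (for squared loss, simply the residuals). It is then added with a shrinkage factor, the learning rate. The sketch below implements that loop for regression with scikit-learn trees. It is an illustration only and omits what makes LightGBM fast, such as histogram-based split finding and leaf-wise tree growth.

```python
import numpy as np
from sklearn.tree import DecisionTreeRegressor

def fit_gbdt(X, y, n_rounds=50, learning_rate=0.1, max_depth=2):
    """Minimal gradient boosting for squared loss: each tree fits the current residuals."""
    base = float(np.mean(y))                  # round 0: a constant prediction
    pred = np.full(len(y), base)
    trees = []
    for _ in range(n_rounds):
        residual = y - pred                   # negative gradient of 0.5 * (y - pred) ** 2
        tree = DecisionTreeRegressor(max_depth=max_depth, random_state=0).fit(X, residual)
        pred += learning_rate * tree.predict(X)   # add a shrunken copy of the new tree
        trees.append(tree)
    return base, learning_rate, trees

def predict_gbdt(model, X):
    base, learning_rate, trees = model
    return base + learning_rate * sum(tree.predict(X) for tree in trees)

# Hypothetical toy regression data: a noisy sine wave
rng = np.random.default_rng(0)
X = rng.uniform(-3, 3, size=(200, 1))
y = np.sin(X[:, 0]) + rng.normal(scale=0.1, size=200)

for n_rounds in (1, 10, 100):
    mse = np.mean((predict_gbdt(fit_gbdt(X, y, n_rounds=n_rounds), X) - y) ** 2)
    print(n_rounds, 'rounds: training MSE =', round(float(mse), 4))
```

The training error keeps falling as rounds are added. This is also why boosting can overfit, as the LightGBM log further down shows.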

```python
# LightGBM
import lightgbm as lgb
import numpy as np
import matplotlib.pyplot as plt
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, f1_score
# Load the data
iris = datasets.load_iris()
X_train, X_test, y_train, y_test = train_test_split(iris.data, iris.target, test_size = 0.3)
# Convert to LightGBM Dataset objects
train_data = lgb.Dataset(X_train, label = y_train)
validation_data = lgb.Dataset(X_test, label = y_test)
# Parameters
results = {}  # filled with the per-iteration metric history by record_evaluation
params = {'learning_rate': 0.1, 'lambda_l1': 0.1, 'lambda_l2': 0.9, 'max_depth': 1, 'objective': 'multiclass', 'num_class': 3, 'verbose': -1}
# Train the model (callbacks replace the evals_result / verbose_eval arguments removed in LightGBM 4.0)
gbm = lgb.train(params, train_data, valid_sets = [validation_data, train_data], valid_names = ['validate', 'train'],
                callbacks = [lgb.record_evaluation(results), lgb.log_evaluation(period = 1)])
# Predict: for multiclass, predict() returns an (n_samples, 3) probability matrix, so take the argmax of each row
y_pred_test = np.argmax(gbm.predict(X_test), axis = 1)
y_pred_data = np.argmax(gbm.predict(X_train), axis = 1)
# Evaluate
print(accuracy_score(y_test, y_pred_test))
print('train', f1_score(y_train, y_pred_data, average = 'macro'))
print('valid', f1_score(y_test, y_pred_test, average = 'macro'))
lgb.plot_metric(results)
plt.show()
```

Output:

[1]	train's multi_logloss: 0.976938	validate's multi_logloss: 0.981023
[2]	train's multi_logloss: 0.875592	validate's multi_logloss: 0.88226
[3]	train's multi_logloss: 0.788956	validate's multi_logloss: 0.802856
[4]	train's multi_logloss: 0.713724	validate's multi_logloss: 0.734451
[5]	train's multi_logloss: 0.649018	validate's multi_logloss: 0.669279
[6]	train's multi_logloss: 0.591531	validate's multi_logloss: 0.62037
[7]	train's multi_logloss: 0.541379	validate's multi_logloss: 0.573368
[8]	train's multi_logloss: 0.496263	validate's multi_logloss: 0.533024
[9]	train's multi_logloss: 0.456348	validate's multi_logloss: 0.50304
[10]	train's multi_logloss: 0.420516	validate's multi_logloss: 0.466105
[11]	train's multi_logloss: 0.388361	validate's multi_logloss: 0.444444
[12]	train's multi_logloss: 0.359458	validate's multi_logloss: 0.414478
[13]	train's multi_logloss: 0.333213	validate's multi_logloss: 0.393904
[14]	train's multi_logloss: 0.30962	validate's multi_logloss: 0.378558
[15]	train's multi_logloss: 0.28816	validate's multi_logloss: 0.356121
[16]	train's multi_logloss: 0.268664	validate's multi_logloss: 0.343697
[17]	train's multi_logloss: 0.250909	validate's multi_logloss: 0.329402
[18]	train's multi_logloss: 0.234768	validate's multi_logloss: 0.318374
[19]	train's multi_logloss: 0.220133	validate's multi_logloss: 0.309139
[20]	train's multi_logloss: 0.206499	validate's multi_logloss: 0.297277
[21]	train's multi_logloss: 0.194156	validate's multi_logloss: 0.287945
[22]	train's multi_logloss: 0.182832	validate's multi_logloss: 0.284079
[23]	train's multi_logloss: 0.172406	validate's multi_logloss: 0.274423
[24]	train's multi_logloss: 0.162929	validate's multi_logloss: 0.270819
[25]	train's multi_logloss: 0.154074	validate's multi_logloss: 0.266948
[26]	train's multi_logloss: 0.14596	validate's multi_logloss: 0.259182
[27]	train's multi_logloss: 0.138531	validate's multi_logloss: 0.256827
[28]	train's multi_logloss: 0.131556	validate's multi_logloss: 0.254639
[29]	train's multi_logloss: 0.125138	validate's multi_logloss: 0.250006
[30]	train's multi_logloss: 0.119295	validate's multi_logloss: 0.244976
[31]	train's multi_logloss: 0.113977	validate's multi_logloss: 0.241169
[32]	train's multi_logloss: 0.10876	validate's multi_logloss: 0.240898
[33]	train's multi_logloss: 0.103932	validate's multi_logloss: 0.239186
[34]	train's multi_logloss: 0.099617	validate's multi_logloss: 0.236112
[35]	train's multi_logloss: 0.0954041	validate's multi_logloss: 0.236957
[36]	train's multi_logloss: 0.0915208	validate's multi_logloss: 0.237267
[37]	train's multi_logloss: 0.0880111	validate's multi_logloss: 0.236139
[38]	train's multi_logloss: 0.0845806	validate's multi_logloss: 0.236863
[39]	train's multi_logloss: 0.081451	validate's multi_logloss: 0.237791
[40]	train's multi_logloss: 0.0785937	validate's multi_logloss: 0.23897
[41]	train's multi_logloss: 0.0759003	validate's multi_logloss: 0.238387
[42]	train's multi_logloss: 0.0733256	validate's multi_logloss: 0.239295
[43]	train's multi_logloss: 0.0707755	validate's multi_logloss: 0.241507
[44]	train's multi_logloss: 0.0684962	validate's multi_logloss: 0.24261
[45]	train's multi_logloss: 0.0662404	validate's multi_logloss: 0.244052
[46]	train's multi_logloss: 0.0642087	validate's multi_logloss: 0.246097
[47]	train's multi_logloss: 0.0623323	validate's multi_logloss: 0.247787
[48]	train's multi_logloss: 0.0605096	validate's multi_logloss: 0.246881
[49]	train's multi_logloss: 0.0588124	validate's multi_logloss: 0.248425
[50]	train's multi_logloss: 0.0571935	validate's multi_logloss: 0.250581
[51]	train's multi_logloss: 0.0556659	validate's multi_logloss: 0.252133
[52]	train's multi_logloss: 0.0542134	validate's multi_logloss: 0.254299
[53]	train's multi_logloss: 0.0528306	validate's multi_logloss: 0.255876
[54]	train's multi_logloss: 0.0515196	validate's multi_logloss: 0.257431
[55]	train's multi_logloss: 0.0501929	validate's multi_logloss: 0.257479
[56]	train's multi_logloss: 0.0489915	validate's multi_logloss: 0.259634
[57]	train's multi_logloss: 0.0478563	validate's multi_logloss: 0.261256
[58]	train's multi_logloss: 0.0467488	validate's multi_logloss: 0.263452
[59]	train's multi_logloss: 0.0456721	validate's multi_logloss: 0.26314
[60]	train's multi_logloss: 0.0446422	validate's multi_logloss: 0.262897
[61]	train's multi_logloss: 0.0436813	validate's multi_logloss: 0.264889
[62]	train's multi_logloss: 0.042755	validate's multi_logloss: 0.265027
[63]	train's multi_logloss: 0.0418598	validate's multi_logloss: 0.267001
[64]	train's multi_logloss: 0.0410012	validate's multi_logloss: 0.268787
[65]	train's multi_logloss: 0.0401795	validate's multi_logloss: 0.270944
[66]	train's multi_logloss: 0.0393914	validate's multi_logloss: 0.272873
[67]	train's multi_logloss: 0.0386258	validate's multi_logloss: 0.272925
[68]	train's multi_logloss: 0.0379039	validate's multi_logloss: 0.275053
[69]	train's multi_logloss: 0.0371902	validate's multi_logloss: 0.276963
[70]	train's multi_logloss: 0.0365181	validate's multi_logloss: 0.278877
[71]	train's multi_logloss: 0.0358715	validate's multi_logloss: 0.280588
[72]	train's multi_logloss: 0.0352353	validate's multi_logloss: 0.282603
[73]	train's multi_logloss: 0.0346372	validate's multi_logloss: 0.284594
[74]	train's multi_logloss: 0.0340456	validate's multi_logloss: 0.286232
[75]	train's multi_logloss: 0.0334864	validate's multi_logloss: 0.288148
[76]	train's multi_logloss: 0.0329445	validate's multi_logloss: 0.289775
[77]	train's multi_logloss: 0.0324157	validate's multi_logloss: 0.291646
[78]	train's multi_logloss: 0.031911	validate's multi_logloss: 0.29161
[79]	train's multi_logloss: 0.0314163	validate's multi_logloss: 0.293302
[80]	train's multi_logloss: 0.0309399	validate's multi_logloss: 0.294895
[81]	train's multi_logloss: 0.0304758	validate's multi_logloss: 0.29668
[82]	train's multi_logloss: 0.0300298	validate's multi_logloss: 0.298222
[83]	train's multi_logloss: 0.0295942	validate's multi_logloss: 0.299954
[84]	train's multi_logloss: 0.0291753	validate's multi_logloss: 0.300086
[85]	train's multi_logloss: 0.0287618	validate's multi_logloss: 0.301625
[86]	train's multi_logloss: 0.0283652	validate's multi_logloss: 0.303182
[87]	train's multi_logloss: 0.0279776	validate's multi_logloss: 0.304707
[88]	train's multi_logloss: 0.0276011	validate's multi_logloss: 0.305106
[89]	train's multi_logloss: 0.0272367	validate's multi_logloss: 0.306678
[90]	train's multi_logloss: 0.026883	validate's multi_logloss: 0.30822
[91]	train's multi_logloss: 0.0265358	validate's multi_logloss: 0.309666
[92]	train's multi_logloss: 0.0262009	validate's multi_logloss: 0.309765
[93]	train's multi_logloss: 0.0258715	validate's multi_logloss: 0.311171
[94]	train's multi_logloss: 0.0255548	validate's multi_logloss: 0.312661
[95]	train's multi_logloss: 0.0252423	validate's multi_logloss: 0.312928
[96]	train's multi_logloss: 0.0249416	validate's multi_logloss: 0.314297
[97]	train's multi_logloss: 0.0246452	validate's multi_logloss: 0.315116
[98]	train's multi_logloss: 0.0243606	validate's multi_logloss: 0.31647
[99]	train's multi_logloss: 0.0240774	validate's multi_logloss: 0.317843
[100]	train's multi_logloss: 0.0238036	validate's multi_logloss: 0.318235
0.9333333333333333
train 1.0
valid 0.9332591768631814
[1]	train's multi_logloss: 0.981573	validate's multi_logloss: 0.985608
[2]	train's multi_logloss: 0.883927	validate's multi_logloss: 0.89032
[3]	train's multi_logloss: 0.800254	validate's multi_logloss: 0.811111
[4]	train's multi_logloss: 0.727456	validate's multi_logloss: 0.747084
[5]	train's multi_logloss: 0.664656	validate's multi_logloss: 0.683869
[6]	train's multi_logloss: 0.60884	validate's multi_logloss: 0.636064
[7]	train's multi_logloss: 0.560072	validate's multi_logloss: 0.590141
[8]	train's multi_logloss: 0.516133	validate's multi_logloss: 0.550807
[9]	train's multi_logloss: 0.477245	validate's multi_logloss: 0.517379
[10]	train's multi_logloss: 0.442304	validate's multi_logloss: 0.485167
[11]	train's multi_logloss: 0.410901	validate's multi_logloss: 0.463524
[12]	train's multi_logloss: 0.382674	validate's multi_logloss: 0.435015
[13]	train's multi_logloss: 0.357052	validate's multi_logloss: 0.414516
[14]	train's multi_logloss: 0.334035	validate's multi_logloss: 0.39999
[15]	train's multi_logloss: 0.31286	validate's multi_logloss: 0.377932
[16]	train's multi_logloss: 0.293857	validate's multi_logloss: 0.365563
[17]	train's multi_logloss: 0.276418	validate's multi_logloss: 0.350061
[18]	train's multi_logloss: 0.260565	validate's multi_logloss: 0.337537
[19]	train's multi_logloss: 0.246165	validate's multi_logloss: 0.327766
[20]	train's multi_logloss: 0.2328	validate's multi_logloss: 0.316106
[21]	train's multi_logloss: 0.220709	validate's multi_logloss: 0.308427
[22]	train's multi_logloss: 0.209576	validate's multi_logloss: 0.298055
[23]	train's multi_logloss: 0.199358	validate's multi_logloss: 0.292903
[24]	train's multi_logloss: 0.18999	validate's multi_logloss: 0.287517
[25]	train's multi_logloss: 0.1813	validate's multi_logloss: 0.279557
[26]	train's multi_logloss: 0.173332	validate's multi_logloss: 0.275507
[27]	train's multi_logloss: 0.165919	validate's multi_logloss: 0.271781
[28]	train's multi_logloss: 0.159137	validate's multi_logloss: 0.265719
[29]	train's multi_logloss: 0.152858	validate's multi_logloss: 0.264156
[30]	train's multi_logloss: 0.147033	validate's multi_logloss: 0.26041
[31]	train's multi_logloss: 0.141605	validate's multi_logloss: 0.256151
[32]	train's multi_logloss: 0.136616	validate's multi_logloss: 0.255597
[33]	train's multi_logloss: 0.131889	validate's multi_logloss: 0.252857
[34]	train's multi_logloss: 0.127604	validate's multi_logloss: 0.24959
[35]	train's multi_logloss: 0.123542	validate's multi_logloss: 0.249568
[36]	train's multi_logloss: 0.119815	validate's multi_logloss: 0.246283
[37]	train's multi_logloss: 0.116406	validate's multi_logloss: 0.243429
[38]	train's multi_logloss: 0.113095	validate's multi_logloss: 0.243128
[39]	train's multi_logloss: 0.1101	validate's multi_logloss: 0.241747
[40]	train's multi_logloss: 0.107342	validate's multi_logloss: 0.241385
[41]	train's multi_logloss: 0.104745	validate's multi_logloss: 0.242009
[42]	train's multi_logloss: 0.102293	validate's multi_logloss: 0.239914
[43]	train's multi_logloss: 0.0998498	validate's multi_logloss: 0.239593
[44]	train's multi_logloss: 0.0975656	validate's multi_logloss: 0.239634
[45]	train's multi_logloss: 0.095428	validate's multi_logloss: 0.23973
[46]	train's multi_logloss: 0.09341	validate's multi_logloss: 0.239593
[47]	train's multi_logloss: 0.0915987	validate's multi_logloss: 0.238425
[48]	train's multi_logloss: 0.0899355	validate's multi_logloss: 0.238867
[49]	train's multi_logloss: 0.0883439	validate's multi_logloss: 0.239669
[50]	train's multi_logloss: 0.0868249	validate's multi_logloss: 0.238542
[51]	train's multi_logloss: 0.0853989	validate's multi_logloss: 0.239364
[52]	train's multi_logloss: 0.0840528	validate's multi_logloss: 0.239413
[53]	train's multi_logloss: 0.0827655	validate's multi_logloss: 0.238817
[54]	train's multi_logloss: 0.0815579	validate's multi_logloss: 0.239628
[55]	train's multi_logloss: 0.0804065	validate's multi_logloss: 0.23992
[56]	train's multi_logloss: 0.0793187	validate's multi_logloss: 0.240284
[57]	train's multi_logloss: 0.0782935	validate's multi_logloss: 0.24064
[58]	train's multi_logloss: 0.0773119	validate's multi_logloss: 0.241139
[59]	train's multi_logloss: 0.0763733	validate's multi_logloss: 0.24082
[60]	train's multi_logloss: 0.0754934	validate's multi_logloss: 0.241047
[61]	train's multi_logloss: 0.0746551	validate's multi_logloss: 0.241338
[62]	train's multi_logloss: 0.0738464	validate's multi_logloss: 0.241815
[63]	train's multi_logloss: 0.073084	validate's multi_logloss: 0.242101
[64]	train's multi_logloss: 0.0723531	validate's multi_logloss: 0.24226
[65]	train's multi_logloss: 0.0716587	validate's multi_logloss: 0.242809
[66]	train's multi_logloss: 0.0709938	validate's multi_logloss: 0.243073
[67]	train's multi_logloss: 0.0703601	validate's multi_logloss: 0.243336
[68]	train's multi_logloss: 0.0697568	validate's multi_logloss: 0.243481
[69]	train's multi_logloss: 0.0691771	validate's multi_logloss: 0.243966
[70]	train's multi_logloss: 0.0686266	validate's multi_logloss: 0.24434
[71]	train's multi_logloss: 0.0680923	validate's multi_logloss: 0.24448
[72]	train's multi_logloss: 0.0675908	validate's multi_logloss: 0.244918
[73]	train's multi_logloss: 0.0671031	validate's multi_logloss: 0.245043
[74]	train's multi_logloss: 0.066642	validate's multi_logloss: 0.24534
[75]	train's multi_logloss: 0.0661947	validate's multi_logloss: 0.24546
[76]	train's multi_logloss: 0.0657717	validate's multi_logloss: 0.245846
[77]	train's multi_logloss: 0.0653627	validate's multi_logloss: 0.246028
[78]	train's multi_logloss: 0.0649723	validate's multi_logloss: 0.246223
[79]	train's multi_logloss: 0.0645964	validate's multi_logloss: 0.246321
[80]	train's multi_logloss: 0.0642376	validate's multi_logloss: 0.246508
[81]	train's multi_logloss: 0.0638923	validate's multi_logloss: 0.246662
[82]	train's multi_logloss: 0.0635608	validate's multi_logloss: 0.246979
[83]	train's multi_logloss: 0.0632425	validate's multi_logloss: 0.247059
[84]	train's multi_logloss: 0.0629365	validate's multi_logloss: 0.247227
[85]	train's multi_logloss: 0.0626434	validate's multi_logloss: 0.247345
[86]	train's multi_logloss: 0.0623601	validate's multi_logloss: 0.247621
[87]	train's multi_logloss: 0.0620874	validate's multi_logloss: 0.247688
[88]	train's multi_logloss: 0.0618245	validate's multi_logloss: 0.247881
[89]	train's multi_logloss: 0.0615721	validate's multi_logloss: 0.247934
[90]	train's multi_logloss: 0.0613283	validate's multi_logloss: 0.248175
[91]	train's multi_logloss: 0.0610936	validate's multi_logloss: 0.248227
[92]	train's multi_logloss: 0.0608669	validate's multi_logloss: 0.248395
[93]	train's multi_logloss: 0.0606489	validate's multi_logloss: 0.248442
[94]	train's multi_logloss: 0.0604381	validate's multi_logloss: 0.248563
[95]	train's multi_logloss: 0.060235	validate's multi_logloss: 0.248645
[96]	train's multi_logloss: 0.0600386	validate's multi_logloss: 0.248843
[97]	train's multi_logloss: 0.0598487	validate's multi_logloss: 0.24889
[98]	train's multi_logloss: 0.0596653	validate's multi_logloss: 0.249037
[99]	train's multi_logloss: 0.0594871	validate's multi_logloss: 0.249088
[100]	train's multi_logloss: 0.0593146	validate's multi_logloss: 0.249275
0.9333333333333333
train 1.0
valid 0.9332591768631814

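A higher validation loss than training loss is normal on its own. The warning sign is the trend. In the first log above, the validation loss bottoms out at 0.236112 at iteration 34 and then climbs back to 0.318235 by iteration 100. Meanwhile, the training loss keeps falling to 0.0238. This widening gap means the model is overfitting.
Ways to reduce overfitting in LightGBM: regularize with lambda_l1, lambda_l2 and min_gain_to_split; use a smaller max_bin; use a smaller num_leaves; and so on.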
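The turning point can be read off the recorded history: the best number of rounds is the iteration with the lowest validation loss. Below is a small helper, `best_iteration` (a hypothetical name), applied to iterations 30 to 40 of the first log:

```python
def best_iteration(history, start=1):
    """Return (iteration number, loss) for the lowest loss in a per-iteration history."""
    best_idx = min(range(len(history)), key=history.__getitem__)
    return start + best_idx, history[best_idx]

# Validation multi_logloss for iterations 30-40, copied from the first log above.
# In a real run the full history is results['validate']['multi_logloss'].
losses = [0.244976, 0.241169, 0.240898, 0.239186, 0.236112, 0.236957,
          0.237267, 0.236139, 0.236863, 0.237791, 0.23897]
print(best_iteration(losses, start=30))  # (34, 0.236112)
```

`lgb.train` can also stop automatically. Pass `callbacks=[lgb.early_stopping(stopping_rounds=10)]` together with a validation set (10 is an arbitrary choice). The chosen round is then stored in `gbm.best_iteration`.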
Here lambda_l1 is raised from 0.1 to 0.9:

```python
# Same setup as the previous block, with stronger L1 regularization (lambda_l1 = 0.9);
# reuses train_data, validation_data, X_train, X_test, y_train, y_test and results from that block
params = {'learning_rate': 0.1, 'lambda_l1': 0.9, 'lambda_l2': 0.9, 'max_depth': 1, 'objective': 'multiclass', 'num_class': 3, 'verbose': -1}
# Train the model
gbm = lgb.train(params, train_data, valid_sets = [validation_data, train_data], valid_names = ['validate', 'train'],
                callbacks = [lgb.record_evaluation(results), lgb.log_evaluation(period = 1)])
# Predict (argmax over the class-probability columns)
y_pred_test = np.argmax(gbm.predict(X_test), axis = 1)
y_pred_data = np.argmax(gbm.predict(X_train), axis = 1)
# Evaluate
print(accuracy_score(y_test, y_pred_test))
print('train', f1_score(y_train, y_pred_data, average = 'macro'))
print('valid', f1_score(y_test, y_pred_test, average = 'macro'))
lgb.plot_metric(results)
# Plot feature importance: the number of times each feature is used in a split
lgb.plot_importance(gbm, importance_type = 'split')
plt.show()
```


XGBoost model

XGBoost learning material: https://mp.weixin.qq.com/s/AAKPSIHk1iUqCeUibrORqQ

```python
# XGBoost
from sklearn.datasets import load_iris
import xgboost as xgb
from xgboost import plot_importance
from matplotlib import pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score   # macro F1 is used as the evaluation metric
# Load the sample dataset
iris = load_iris()
X, y = iris.data, iris.target
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=1234565)  # train/test split
# Algorithm parameters
params = {
    'booster': 'gbtree',
    'objective': 'multi:softmax',  # predict() returns the class label directly
    'eval_metric': 'mlogloss',
    'num_class': 3,
    'gamma': 0.1,
    'max_depth': 6,
    'lambda': 2,
    'subsample': 0.7,
    'colsample_bytree': 0.75,
    'min_child_weight': 3,
    'eta': 0.1,
    'seed': 1,
    'nthread': 4,
}

train_data = xgb.DMatrix(X_train, label=y_train)  # build the DMatrix data format
num_rounds = 500
# Note: num_rounds is not passed, so xgb.train uses its default num_boost_round=10;
# call xgb.train(params, train_data, num_boost_round=num_rounds) to actually train 500 rounds
model = xgb.train(params, train_data)  # train the XGBoost model

# Predict on the test set
dtest = xgb.DMatrix(X_test)
y_pred = model.predict(dtest).astype(int)  # multi:softmax returns labels as floats

# Compute the macro F1 score
F1_score = f1_score(y_test, y_pred, average='macro')
print("F1_score: %.2f%%" % (F1_score * 100.0))

# Show feature importance
plot_importance(model)
plt.show()
```

F1_score: 95.56%
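With `'objective': 'multi:softmax'`, XGBoost's `predict` returns the index of the winning class (as a float). `'multi:softprob'` instead returns an (n_samples, num_class) matrix of probabilities: the softmax of the per-class raw scores. The sketch below shows the relationship with NumPy on made-up raw scores (no XGBoost needed). The function names are hypothetical.

```python
import numpy as np

def softprob(margins):
    """Row-wise softmax of raw per-class scores (the 'multi:softprob' output)."""
    shifted = margins - margins.max(axis=1, keepdims=True)  # subtract the row max for numerical stability
    exp = np.exp(shifted)
    return exp / exp.sum(axis=1, keepdims=True)

def softmax_labels(margins):
    """Index of the largest score per row, as a float (the 'multi:softmax' output)."""
    return margins.argmax(axis=1).astype(float)

# Made-up raw scores for two samples and three classes
margins = np.array([[2.0, 0.5, -1.0],
                    [0.1, 0.3, 2.2]])
probs = softprob(margins)
print(probs.round(3))
print(softmax_labels(margins))  # [0. 2.]
```

Use softprob when you need probabilities (for log loss, thresholds or blending). `np.argmax(probs, axis=1)` recovers the softmax labels.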

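Cross-validation

Cross-validation splits the original data into groups, using some as the training set and the rest as the validation set. The main variants are simple (hold-out) cross-validation, k-fold cross-validation, leave-one-out cross-validation and leave-p-out cross-validation. The code is as follows: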

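```python
# Cross-validation
# Simple (hold-out) cross-validation: one 60/40 split
from sklearn.model_selection import train_test_split
from sklearn import datasets
# Load the dataset
iris = datasets.load_iris()
feature = iris.feature_names
X = iris.data
y = iris.target
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size = 0.4, random_state = 0)

# K-fold cross-validation: 10 folds, shuffled before splitting
from sklearn.model_selection import KFold
folds = KFold(n_splits = 10, shuffle = True)

# Leave-one-out cross-validation: each sample is the validation set exactly once
from sklearn.model_selection import LeaveOneOut
loo = LeaveOneOut()

# Leave-p-out cross-validation: every subset of p = 5 samples is a validation set once
from sklearn.model_selection import LeavePOut
lpo = LeavePOut(p = 5)
```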
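The splitter objects above only define how the data is partitioned. To get scores, pass one as `cv` to `cross_val_score`. The sketch below uses a random forest with macro F1 (the model and its settings are illustrative choices). It also counts how many splits each scheme produces on iris.

```python
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import KFold, LeaveOneOut, LeavePOut, cross_val_score

X, y = load_iris(return_X_y=True)
kfold = KFold(n_splits=10, shuffle=True, random_state=0)
clf = RandomForestClassifier(n_estimators=50, random_state=0)

# One fit and one macro-F1 score per fold
scores = cross_val_score(clf, X, y, cv=kfold, scoring='f1_macro')
print('10-fold macro F1: %.3f +/- %.3f' % (scores.mean(), scores.std()))

# Number of train/validation splits each scheme would need on 150 samples
n_loo = LeaveOneOut().get_n_splits(X)   # one split per sample
n_lpo = LeavePOut(p=5).get_n_splits(X)  # C(150, 5) splits
print(n_loo, n_lpo)
```

The leave-p-out count (591,600,030 splits for p = 5 on 150 samples) shows why that scheme is rarely practical beyond tiny datasets.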

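Hyperparameter tuning

1. Grid search: exhaustive search over every combination of the candidate parameter values;
2. Learning curve: plot the model's accuracy on the training set and under cross-validation for different training-set sizes, to see how the model performs on new data and judge whether it has high variance.
Grid search is time-consuming, so the code below was not run to completion.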
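It is easy to quantify how slow the exhaustive search is. scikit-learn's `ParameterGrid` enumerates the same combinations `GridSearchCV` tries. The grid in the XGBoost example has ten hyperparameters with five candidate values each:

```python
from sklearn.model_selection import ParameterGrid

# The same ten hyperparameters and candidate values as in the XGBoost grid-search code
parameters = {
    'max_depth': [5, 10, 15, 20, 25],
    'learning_rate': [0.01, 0.02, 0.05, 0.1, 0.15],
    'n_estimators': [50, 100, 200, 300, 500],
    'min_child_weight': [0, 2, 5, 10, 20],
    'max_delta_step': [0, 0.2, 0.6, 1, 2],
    'subsample': [0.6, 0.7, 0.8, 0.85, 0.95],
    'colsample_bytree': [0.5, 0.6, 0.7, 0.8, 0.9],
    'reg_alpha': [0, 0.25, 0.5, 0.75, 1],
    'reg_lambda': [0.2, 0.4, 0.6, 0.8, 1],
    'scale_pos_weight': [0.2, 0.4, 0.6, 0.8, 1],
}
n_candidates = len(ParameterGrid(parameters))  # product of the list lengths
print(n_candidates)       # 9765625, i.e. 5 ** 10
print(n_candidates * 3)   # 29296875 model fits with cv=3
```

With `cv=3` that is about 29 million XGBoost fits, plus one final refit of the best setting. `RandomizedSearchCV` with a fixed `n_iter` samples only that many combinations and is a common cheaper alternative.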

```python
# Grid search example, using XGBoost
import numpy as np
import xgboost as xgb
from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.datasets import load_breast_cancer
cancer = load_breast_cancer()
x = cancer.data[:50]
y = cancer.target[:50]
train_x, valid_x, train_y, valid_y = train_test_split(x, y, test_size=0.333, random_state=0)  # train/validation split
# The scikit-learn wrapper (XGBClassifier) takes NumPy arrays directly, so no DMatrix is needed
parameters = {
    'max_depth': [5, 10, 15, 20, 25],
    'learning_rate': [0.01, 0.02, 0.05, 0.1, 0.15],
    'n_estimators': [50, 100, 200, 300, 500],
    'min_child_weight': [0, 2, 5, 10, 20],
    'max_delta_step': [0, 0.2, 0.6, 1, 2],
    'subsample': [0.6, 0.7, 0.8, 0.85, 0.95],
    'colsample_bytree': [0.5, 0.6, 0.7, 0.8, 0.9],
    'reg_alpha': [0, 0.25, 0.5, 0.75, 1],
    'reg_lambda': [0.2, 0.4, 0.6, 0.8, 1],
    'scale_pos_weight': [0.2, 0.4, 0.6, 0.8, 1]
}
xlf = xgb.XGBClassifier(max_depth=10,
                        learning_rate=0.01,
                        n_estimators=2000,
                        objective='binary:logistic',
                        n_jobs=-1,           # current name of the old nthread argument
                        gamma=0,
                        min_child_weight=1,
                        max_delta_step=0,
                        subsample=0.85,
                        colsample_bytree=0.7,
                        colsample_bylevel=1,
                        reg_alpha=0,
                        reg_lambda=1,
                        scale_pos_weight=1,
                        random_state=1440,   # current name of the old seed argument
                        missing=np.nan)      # NaN marks missing values
# GridSearchCV fits the estimator itself (once per parameter combination and fold),
# so fit is called on the search object instead of on xlf
gsearch = GridSearchCV(xlf, param_grid=parameters, scoring='accuracy', cv=3)
gsearch.fit(train_x, train_y)
print("Best score: %0.3f" % gsearch.best_score_)
print("Best parameters set:")
best_parameters = gsearch.best_estimator_.get_params()
for param_name in sorted(parameters.keys()):
    print("\t%s: %r" % (param_name, best_parameters[param_name]))
```
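The learning curve (item 2 in the list above) has a direct scikit-learn implementation, `learning_curve`. The sketch below uses the full breast-cancer dataset (569 samples). A random forest is swapped in for XGBoost to keep it dependency-light; the estimator and training sizes are illustrative.

```python
import numpy as np
from sklearn.datasets import load_breast_cancer
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import learning_curve

X, y = load_breast_cancer(return_X_y=True)
# Train on 10%, 32.5%, 55%, 77.5% and 100% of each training fold, with 5-fold CV
train_sizes, train_scores, valid_scores = learning_curve(
    RandomForestClassifier(n_estimators=30, random_state=0), X, y,
    train_sizes=np.linspace(0.1, 1.0, 5), cv=5, scoring='accuracy',
    shuffle=True, random_state=0)

for n, tr, va in zip(train_sizes, train_scores.mean(axis=1), valid_scores.mean(axis=1)):
    print(f'{n:4d} samples: train acc = {tr:.3f}, cv acc = {va:.3f}, gap = {tr - va:.3f}')
```

A large gap between training and cross-validation accuracy that does not close as samples are added points to high variance (overfitting). Two curves that converge at a low accuracy point to high bias.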

Smart Ocean competition code

Note on imports: use `import config` instead of `from config import config`.

```python
import config
import pandas as pd
import numpy as np
from tqdm import tqdm
from sklearn.metrics import classification_report, f1_score
from sklearn.model_selection import StratifiedKFold, KFold, train_test_split
import lightgbm as lgb
import os
import warnings
from hyperopt import fmin, tpe, hp, STATUS_OK, Trials
# Adjust this path to where group_df.csv is stored
all_df = pd.read_csv(r"C:\Users\李\Desktop\datawheal\data\group_df.csv", index_col = 0)
use_train = all_df[all_df['label'] != -1]
use_test = all_df[all_df['label'] == -1]  # rows with label -1 form the test set
use_feats = [c for c in use_train.columns if c not in ['ID', 'label']]
X_train, X_verify, y_train, y_verify = train_test_split(use_train[use_feats], use_train['label'], test_size = 0.3, random_state = 0)
```

Feature selection by importance:

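```python
# Feature selection by importance
selectFeatures = 200           # number of features to keep
earlyStopping = 100            # early-stopping patience (rounds)
select_num_boost_round = 1000  # boosting rounds for the feature-selection model
# Base parameters
selfParam = {
    'learning_rate': 0.01,      # learning rate
    'boosting': 'dart',         # boosting type: gbdt or dart
    'objective': 'multiclass',  # multiclass classification
    'metric': 'None',           # disable built-in metrics
    'num_leaves': 32,
    'feature_fraction': 0.7,    # fraction of features used per tree
    'bagging_fraction': 0.8,    # fraction of samples per bagging round; only active when bagging_freq > 0 (not set here)
    'min_data_in_leaf': 30,     # minimum number of samples per leaf
    'num_class': 3,
    'max_depth': 6,             # maximum tree depth
    'num_threads': 8,           # number of LightGBM threads
    'min_data_in_bin': 30,      # minimum number of samples per histogram bin
    'max_bin': 256,             # maximum number of bins per feature
    'is_unbalance': True,       # unbalanced classes; LightGBM only uses this for binary / multiclassova objectives
    'train_metric': True,       # also report metrics on the training data
    'verbose': -1,
}
```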

The feature-selection result (feature name, importance) is as follows:

select forward 200 features:[('pos_neq_zero_speed_q_40', 1783), ('lat_lon_countvec_1_x', 1771), ('rank2_mode_lat', 1737), ('pos_neq_zero_speed_median', 1379),
('pos_neq_zero_speed_q_60', 1369), ('lat_lon_tfidf_0_x', 1251), ('pos_neq_zero_speed_q_80', 1194), ('sample_tfidf_0_x', 1168),
('w2v_9_mean', 1134), ('lat_lon_tfidf_11_x', 963), ('rank3_mode_lat', 946), ('w2v_5_mean', 900),
('w2v_16_mean', 874), ('pos_neq_zero_speed_q_30', 866), ('w2v_12_mean', 862), ('pos_neq_zero_speed_q_70', 856),
('lat_lon_tfidf_9_x', 787), ('grad_tfidf_7_x', 772), ('pos_neq_zero_speed_q_90', 746), ('rank3_mode_cnt', 733),
('grad_tfidf_12_x', 729), ('w2v_4_mean', 697), ('sample_tfidf_14_x', 695), ('lat_lon_tfidf_4_x', 693),
('lat_min', 683), ('w2v_23_mean', 647), ('rank2_mode_lon', 631), ('w2v_26_mean', 626),
('rank1_mode_lon', 620), ('grad_tfidf_15_x', 607), ('speed_neq_zero_speed_q_90', 603), ('grad_tfidf_5_x', 572),
('lat_lon_countvec_22_x', 571), ('lat_lon_countvec_1_y', 565), ('w2v_13_mean', 557), ('w2v_27_mean', 550),
('grad_tfidf_2_x', 507), ('lat_lon_tfidf_20_x', 503), ('lat_lon_countvec_0_x', 499), ('lat_lon_countvec_18_x', 490),
('sample_tfidf_21_x', 488), ('grad_tfidf_14_x', 484), ('lat_lon_countvec_27_x', 470), ('w2v_22_mean', 466),
('lat_lon_tfidf_1_x', 461), ('direction_nunique', 460), ('lon_max', 457), ('w2v_15_mean', 441),
('grad_tfidf_23_x', 431), ('w2v_19_mean', 429), ('w2v_11_mean', 428), ('lat_lon_tfidf_29_x', 420),
('pos_neq_zero_lon_q_10', 417), ('w2v_3_mean', 411), ('lat_lon_tfidf_0_y', 407), ('sample_tfidf_29_x', 406),
('anchor_cnt', 404), ('grad_tfidf_8_x', 397), ('sample_tfidf_10_x', 397), ('sample_tfidf_12_x', 385),
('w2v_28_mean', 384), ('grad_tfidf_13_x', 381), ('direction_q_90', 380), ('speed_neq_zero_lon_min', 374),
('w2v_25_mean', 371), ('anchor_ratio', 367), ('lat_lon_tfidf_16_x', 367), ('rank1_mode_lat', 365),
('w2v_18_mean', 365), ('sample_tfidf_23_x', 364), ('lon_min', 354), ('grad_tfidf_0_x', 351),
('pos_neq_zero_lat_q_90', 341), ('w2v_20_mean', 341), ('sample_tfidf_4_x', 334), ('lat_lon_tfidf_23_x', 332),
('sample_tfidf_0_y', 328), ('pos_neq_zero_direction_q_90', 326), ('speed_neq_zero_direction_nunique', 326), ('sample_tfidf_19_x', 323),
('lat_lon_countvec_9_x', 319), ('pos_neq_zero_lon_q_90', 314), ('w2v_8_mean', 312), ('grad_tfidf_3_x', 309),
('lon_median', 305), ('pos_neq_zero_speed_q_20', 304), ('lat_lon_countvec_4_x', 304), ('lat_mean', 301),
('speed_neq_zero_lon_max', 301), ('lat_lon_tfidf_14_x', 301), ('speed_neq_zero_lat_min', 300), ('lat_lon_countvec_5_x', 296),
('speed_neq_zero_speed_q_80', 294), ('grad_tfidf_16_x', 293), ('rank3_mode_lon', 292), ('lat_lon_tfidf_18_x', 291),
('w2v_7_mean', 290), ('grad_tfidf_6_x', 285), ('grad_tfidf_20_x', 283), ('grad_tfidf_18_x', 282),
('w2v_0_mean', 280), ('grad_tfidf_21_x', 279), ('grad_tfidf_22_x', 273), ('sample_tfidf_24_x', 273),
('speed_q_90', 271), ('w2v_2_mean', 271), ('lat_max', 264), ('sample_tfidf_9_x', 264),
('grad_tfidf_11_x', 262), ('lon_q_20', 260), ('rank1_mode_cnt', 258), ('speed_max', 256),
('lat_lon_tfidf_12_x', 251), ('pos_neq_zero_lon_q_20', 248), ('lat_lon_tfidf_28_x', 242), ('speed_neq_zero_direction_q_60', 241),
('sample_tfidf_11_x', 241), ('w2v_17_mean', 241), ('sample_tfidf_13_x', 238), ('w2v_14_mean', 236),
('lat_nunique', 235), ('grad_tfidf_4_x', 234), ('w2v_21_mean', 234), ('sample_tfidf_5_x', 231),
('lat_lon_tfidf_9_y', 225), ('speed_neq_zero_lat_q_90', 222), ('direction_median', 221), ('sample_tfidf_17_x', 220),
('sample_tfidf_14_y', 216), ('lat_lon_tfidf_21_x', 215), ('lon_q_10', 214), ('lat_lon_tfidf_22_x', 214),
('grad_tfidf_26_x', 213), ('grad_tfidf_7_y', 213), ('w2v_29_mean', 212), ('pos_neq_zero_lat_q_80', 210),
('cnt', 209), ('lat_lon_tfidf_4_y', 208), ('direction_q_60', 204), ('sample_tfidf_18_x', 203),
('lat_lon_tfidf_11_y', 203), ('pos_neq_zero_lat_min', 202), ('pos_neq_zero_speed_mean', 201), ('speed_neq_zero_lat_q_70', 200),
('grad_tfidf_12_y', 198), ('sample_tfidf_20_x', 197), ('w2v_1_mean', 194), ('speed_neq_zero_lat_q_40', 193),
('pos_neq_zero_speed_max', 192), ('grad_tfidf_27_x', 192), ('grad_tfidf_15_y', 191), ('lat_lon_tfidf_19_x', 189),
('lat_median', 187), ('lat_lon_tfidf_15_x', 187), ('lat_q_20', 186), ('lat_q_70', 186),
('lon_q_70', 185), ('w2v_24_mean', 184), ('pos_neq_zero_lat_q_40', 183), ('grad_tfidf_25_x', 181),
('w2v_10_mean', 181), ('lon_mean', 180), ('sample_tfidf_27_x', 180), ('w2v_6_mean', 180),
('lat_lon_tfidf_24_x', 178), ('lat_lon_countvec_12_x', 178), ('pos_neq_zero_lat_mean', 177), ('speed_neq_zero_speed_q_70', 174),
('speed_neq_zero_direction_q_80', 172), ('rank2_mode_cnt', 172), ('speed_neq_zero_lat_nunique', 171), ('lat_lon_tfidf_2_x', 171),
('sample_tfidf_25_x', 170), ('lat_lon_tfidf_5_x', 169), ('lat_lon_countvec_26_x', 167), ('grad_tfidf_9_x', 166),
('lat_lon_countvec_28_x', 163), ('lat_lon_countvec_22_y', 163), ('sample_tfidf_1_x', 162), ('pos_neq_zero_direction_nunique', 161),
('pos_neq_zero_speed_q_10', 157), ('sample_tfidf_16_x', 155), ('speed_neq_zero_direction_q_90', 154), ('grad_tfidf_14_y', 153),
('lat_lon_tfidf_7_x', 151), ('pos_neq_zero_direction_q_80', 149), ('lat_q_80', 148), ('grad_tfidf_23_y', 148),
('lat_lon_countvec_11_x', 147), ('sample_tfidf_22_x', 146), ('speed_neq_zero_lat_max', 144), ('sample_tfidf_15_x', 144),
('grad_tfidf_2_y', 144), ('pos_neq_zero_lat_q_10', 142), ('lat_lon_tfidf_1_y', 142), ('lat_lon_countvec_16_x', 141),
('grad_tfidf_13_y', 138), ('lat_lon_countvec_29_x', 136), ('lat_lon_tfidf_29_y', 136), ('grad_tfidf_5_y', 136)]

Hyperparameter optimization with Bayesian optimization (hyperopt, TPE algorithm):

```python
# model_feature: names of the top selectFeatures features in the importance ranking
model_feature = [k[0] for k in sort_feature_importance[:selectFeatures]]
# Hyperparameter optimization
# Search space
spaceParam = {'boosting': hp.choice('boosting', ['gbdt', 'dart']),
              'learning_rate': hp.loguniform('learning_rate', np.log(0.01), np.log(0.05)),
              'num_leaves': hp.quniform('num_leaves', 3, 66, 3),
              'feature_fraction': hp.uniform('feature_fraction', 0.7, 1),
              'min_data_in_leaf': hp.quniform('min_data_in_leaf', 10, 50, 5),
              'num_boost_round': hp.quniform('num_boost_round', 500, 2000, 100),
              'bagging_fraction': hp.uniform('bagging_fraction', 0.6, 1)}
# Turn a sampled (or the best) parameter set into a complete LightGBM parameter dict
def getParam(param):
    for k in ['num_leaves', 'min_data_in_leaf', 'num_boost_round']:
        param[k] = int(float(param[k]))  # quniform samples are floats
    for k in ['learning_rate', 'feature_fraction', 'bagging_fraction']:
        param[k] = float(param[k])
    # fmin reports hp.choice results as an index, so map 0/1 back to the name
    if param['boosting'] == 0:
        param['boosting'] = 'gbdt'
    elif param['boosting'] == 1:
        param['boosting'] = 'dart'
    # Add the fixed parameters
    param['objective'] = 'multiclass'
    param['max_depth'] = 7
    param['num_threads'] = 8
    param['is_unbalance'] = True
    param['metric'] = 'None'
    param['train_metric'] = True
    param['verbose'] = -1
    param['bagging_freq'] = 5
    param['num_class'] = 3
    param['feature_pre_filter'] = False
    return param

# Custom evaluation metric: macro F1 of the argmax class
def f1_score_eval(preds, valid_df):
    labels = valid_df.get_label()
    if preds.ndim == 1:
        # older LightGBM: one flat, class-major array of length num_class * n_samples
        preds = preds.reshape(3, -1).T
    # newer LightGBM (4.x) already passes shape (n_samples, num_class)
    preds = np.argmax(preds, axis = 1)
    scores = f1_score(y_true = labels, y_pred = preds, average = 'macro')
    return 'f1_score', scores, True  # True: higher is better

# Objective minimized by hyperopt: 1 - validation macro F1
def lossFun(param):
    param = getParam(param)
    train_param = dict(param)
    # Pass the round count only as an argument (leaving it in params triggers a "Found `num_boost_round` in params" warning)
    num_boost_round = train_param.pop('num_boost_round')
    m = lgb.train(train_param, train_set = train_data, num_boost_round = num_boost_round,
                  valid_sets = [train_data, valid_data], valid_names = ['train', 'valid'],
                  feval = f1_score_eval, keep_training_booster = True,
                  # early stopping is skipped, with a warning, when boosting = 'dart'
                  callbacks = [lgb.early_stopping(earlyStopping, verbose = False)])
    train_f1_score = m.best_score['train']['f1_score']
    valid_f1_score = m.best_score['valid']['f1_score']
    loss_f1_score = 1 - valid_f1_score
    print('train f1_score:{}, valid f1_score:{}, loss_f1_score:{}'.format(train_f1_score, valid_f1_score, loss_f1_score))
    return {'loss': loss_f1_score, 'params': param, 'status': STATUS_OK}

features = model_feature
train_data = lgb.Dataset(data = X_train[features], label = y_train, feature_name = features)
valid_data = lgb.Dataset(data = X_verify[features], label = y_verify, feature_name = features, reference = train_data)
# Search for the best parameters
best_param = fmin(fn = lossFun, space = spaceParam, algo = tpe.suggest, max_evals = 100, trials = Trials())
best_param = getParam(best_param)
print('Search best param:', best_param)
```
  warnings.warn("Found `{}` in params. Will use it instead of argument".format(alias))

E:\anacoda\lib\site-packages\lightgbm\callback.py:186: UserWarning: Early stopping is not available in dart mode
  warnings.warn('Early stopping is not available in dart mode')

train f1_score:1.0, valid f1_score:0.9285309559097713, loss_f1_score:0.0714690440902287
train f1_score:0.9963233527198608, valid f1_score:0.9095083779923133, loss_f1_score:0.09049162200768668
train f1_score:1.0, valid f1_score:0.9257391891204496, loss_f1_score:0.07426081087955039
train f1_score:1.0, valid f1_score:0.9198350308928772, loss_f1_score:0.08016496910712279
train f1_score:0.9992236245819818, valid f1_score:0.9167964902969308, loss_f1_score:0.08320350970306922
train f1_score:1.0, valid f1_score:0.9228775314361938, loss_f1_score:0.07712246856380622
train f1_score:0.9815257139740128, valid f1_score:0.9059852356592902, loss_f1_score:0.09401476434070977
train f1_score:0.9995123993009063, valid f1_score:0.9201072375043884, loss_f1_score:0.07989276249561161
train f1_score:1.0, valid f1_score:0.917180592411146, loss_f1_score:0.08281940758885398
train f1_score:1.0, valid f1_score:0.9230556252348961, loss_f1_score:0.07694437476510385
train f1_score:1.0, valid f1_score:0.9177203389359856, loss_f1_score:0.08227966106401441
train f1_score:0.998757147314354, valid f1_score:0.9104583820708552, loss_f1_score:0.08954161792914483
train f1_score:1.0, valid f1_score:0.914298851135159, loss_f1_score:0.08570114886484104
train f1_score:1.0, valid f1_score:0.9231479325133111, loss_f1_score:0.07685206748668894
train f1_score:1.0, valid f1_score:0.9224745430635202, loss_f1_score:0.07752545693647983
train f1_score:1.0, valid f1_score:0.9174189426487299, loss_f1_score:0.08258105735127008
train f1_score:0.9925304441826487, valid f1_score:0.906781252235943, loss_f1_score:0.09321874776405703
train f1_score:0.9631339775372375, valid f1_score:0.9066989618735376, loss_f1_score:0.0933010381264624
train f1_score:1.0, valid f1_score:0.9250232632818385, loss_f1_score:0.07497673671816152
train f1_score:1.0, valid f1_score:0.9203115079005767, loss_f1_score:0.0796884920994233
train f1_score:1.0, valid f1_score:0.9248984394690476, loss_f1_score:0.0751015605309524
train f1_score:1.0, valid f1_score:0.9215913930032361, loss_f1_score:0.07840860699676389
train f1_score:0.9440002150564252, valid f1_score:0.8935415971895097, loss_f1_score:0.10645840281049035
train f1_score:1.0, valid f1_score:0.926645936839296, loss_f1_score:0.07335406316070403
train f1_score:1.0, valid f1_score:0.9226316974274017, loss_f1_score:0.0773683025725983
train f1_score:1.0, valid f1_score:0.9242701223911126, loss_f1_score:0.07572987760888739
train f1_score:1.0, valid f1_score:0.9198254231957517, loss_f1_score:0.08017457680424833
train f1_score:1.0, valid f1_score:0.9256990122822063, loss_f1_score:0.0743009877177937
train f1_score:1.0, valid f1_score:0.9194102895600668, loss_f1_score:0.08058971043993324
train f1_score:0.9923324494323645, valid f1_score:0.9023188996860295, loss_f1_score:0.09768110031397048
train f1_score:1.0, valid f1_score:0.9207667553153351, loss_f1_score:0.07923324468466486
train f1_score:1.0, valid f1_score:0.9235684862240877, loss_f1_score:0.07643151377591229
train f1_score:1.0, valid f1_score:0.9159299981259025, loss_f1_score:0.08407000187409752
train f1_score:1.0, valid f1_score:0.9201677556683528, loss_f1_score:0.07983224433164715
train f1_score:1.0, valid f1_score:0.9214644211379112, loss_f1_score:0.07853557886208884
train f1_score:1.0, valid f1_score:0.9209037282718199, loss_f1_score:0.07909627172818012
train f1_score:1.0, valid f1_score:0.9126368594667446, loss_f1_score:0.08736314053325545
train f1_score:0.9752612340062626, valid f1_score:0.9007932634486341, loss_f1_score:0.09920673655136591
train f1_score:1.0, valid f1_score:0.9156230890670725, loss_f1_score:0.08437691093292754
train f1_score:1.0, valid f1_score:0.9205297016223192, loss_f1_score:0.07947029837768083
train f1_score:1.0, valid f1_score:0.9246093309118519, loss_f1_score:0.07539066908814807
train f1_score:0.9984695218750549, valid f1_score:0.9123771080919295, loss_f1_score:0.08762289190807049
train f1_score:1.0, valid f1_score:0.9186291073337242, loss_f1_score:0.08137089266627584
train f1_score:1.0, valid f1_score:0.923763538589959, loss_f1_score:0.07623646141004103
train f1_score:1.0, valid f1_score:0.921462222217109, loss_f1_score:0.07853777778289095
train f1_score:1.0, valid f1_score:0.9191128190727075, loss_f1_score:0.08088718092729252
train f1_score:1.0, valid f1_score:0.9190250636977599, loss_f1_score:0.08097493630224006
train f1_score:1.0, valid f1_score:0.9170201116959248, loss_f1_score:0.08297988830407521
train f1_score:1.0, valid f1_score:0.9203767257453489, loss_f1_score:0.0796232742546511
train f1_score:0.9930047208776313, valid f1_score:0.9130655662892675, loss_f1_score:0.0869344337107325
train f1_score:1.0, valid f1_score:0.9189223138054201, loss_f1_score:0.08107768619457989
train f1_score:1.0, valid f1_score:0.9148023758078679, loss_f1_score:0.08519762419213206
train f1_score:1.0, valid f1_score:0.9179944006960786, loss_f1_score:0.08200559930392137
train f1_score:1.0, valid f1_score:0.9289351646437458, loss_f1_score:0.07106483535625419
train f1_score:0.9949613424783701, valid f1_score:0.9168424949905195, loss_f1_score:0.08315750500948049
train f1_score:1.0, valid f1_score:0.9281459056572339, loss_f1_score:0.07185409434276613
train f1_score:0.985057843980505, valid f1_score:0.9072691401349147, loss_f1_score:0.09273085986508534
train f1_score:1.0, valid f1_score:0.9240134358138835, loss_f1_score:0.07598656418611649
train f1_score:1.0, valid f1_score:0.9193661487057992, loss_f1_score:0.08063385129420075
train f1_score:0.9960285051705308, valid f1_score:0.9120716883191599, loss_f1_score:0.08792831168084014
train f1_score:1.0, valid f1_score:0.921146463718955, loss_f1_score:0.07885353628104497
train f1_score:0.9998006409358186, valid f1_score:0.9182667116480335, loss_f1_score:0.08173328835196647
train f1_score:1.0, valid f1_score:0.9211347539389867, loss_f1_score:0.0788652460610133
train f1_score:1.0, valid f1_score:0.9207820427538994, loss_f1_score:0.07921795724610059
train f1_score:1.0, valid f1_score:0.9257423876791567, loss_f1_score:0.07425761232084327
train f1_score:1.0, valid f1_score:0.9221586983297824, loss_f1_score:0.07784130167021763
train f1_score:1.0, valid f1_score:0.9258882305068562, loss_f1_score:0.0741117694931438
train f1_score:1.0, valid f1_score:0.9212498032441728, loss_f1_score:0.07875019675582717
train f1_score:1.0, valid f1_score:0.9292049500198453, loss_f1_score:0.07079504998015473
train f1_score:1.0, valid f1_score:0.9281981188283969, loss_f1_score:0.07180188117160313
train f1_score:1.0, valid f1_score:0.9204853455880198, loss_f1_score:0.07951465441198025
train f1_score:1.0, valid f1_score:0.9256299233258835, loss_f1_score:0.07437007667411655
train f1_score:1.0, valid f1_score:0.925378142730627, loss_f1_score:0.07462185726937298
train f1_score:1.0, valid f1_score:0.924234872816701, loss_f1_score:0.07576512718329897
train f1_score:1.0, valid f1_score:0.9251101235622722, loss_f1_score:0.07488987643772782
train f1_score:1.0, valid f1_score:0.9293946414411237, loss_f1_score:0.07060535855887629
train f1_score:0.9382452765248225, valid f1_score:0.8942311244378818, loss_f1_score:0.10576887556211823
train f1_score:1.0, valid f1_score:0.9194809034931137, loss_f1_score:0.08051909650688627
train f1_score:0.9891553541338339, valid f1_score:0.90941841053602, loss_f1_score:0.09058158946397998
train f1_score:1.0, valid f1_score:0.9201427057573904, loss_f1_score:0.07985729424260957
train f1_score:1.0, valid f1_score:0.9198530476378791, loss_f1_score:0.08014695236212088
train f1_score:1.0, valid f1_score:0.9208104365113973, loss_f1_score:0.07918956348860273
train f1_score:0.9993129514194335, valid f1_score:0.9179142552260235, loss_f1_score:0.08208574477397645
train f1_score:0.9998006409358186, valid f1_score:0.9191648350234985, loss_f1_score:0.08083516497650145
train f1_score:1.0, valid f1_score:0.9190587031090797, loss_f1_score:0.08094129689092033
train f1_score:1.0, valid f1_score:0.9299693184748484, loss_f1_score:0.07003068152515157
train f1_score:1.0, valid f1_score:0.9264077945886542, loss_f1_score:0.07359220541134581
train f1_score:1.0, valid f1_score:0.9291645062607277, loss_f1_score:0.07083549373927234
train f1_score:1.0, valid f1_score:0.9305299545426143, loss_f1_score:0.06947004545738567
train f1_score:1.0, valid f1_score:0.9212264145874105, loss_f1_score:0.0787735854125895
train f1_score:1.0, valid f1_score:0.9226650230252234, loss_f1_score:0.07733497697477665
train f1_score:1.0, valid f1_score:0.9241207734337085, loss_f1_score:0.07587922656629154
train f1_score:1.0, valid f1_score:0.9266162462970313, loss_f1_score:0.07338375370296868
train f1_score:1.0, valid f1_score:0.9289015037853151, loss_f1_score:0.07109849621468489
train f1_score:1.0, valid f1_score:0.926452470167157, loss_f1_score:0.07354752983284296
train f1_score:1.0, valid f1_score:0.9230948398574609, loss_f1_score:0.07690516014253912
train f1_score:1.0, valid f1_score:0.9292092930785151, loss_f1_score:0.07079070692148492
train f1_score:1.0, valid f1_score:0.927623538125419, loss_f1_score:0.07237646187458102
train f1_score:1.0, valid f1_score:0.921699473649286, loss_f1_score:0.07830052635071405
train f1_score:1.0, valid f1_score:0.9201681619180749, loss_f1_score:0.07983183808192507
100%|██████████| 100/100 [59:33<00:00, 35.74s/trial, best loss: 0.06947004545738567]
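hyperopt's fmin minimizes the 'loss' value returned by the objective. This is why lossFun returns 1 - valid F1, so the smallest loss is the largest F1. In the run above, the best loss 0.06947004545738567 belongs to the trial with valid F1 0.9305299545426143. The sketch below shows the same objective contract (a dict with 'loss', 'params' and 'status'; hyperopt's STATUS_OK is the string 'ok'). As an assumption for illustration, plain exhaustive grid search stands in for hyperopt's TPE search, and toy_objective and its scoring formula are made up rather than training a real model:

```python
import itertools

def toy_objective(params):
    # Hypothetical stand-in for lossFun: a made-up "validation F1" that peaks
    # at learning_rate=0.05, num_leaves=60 (real code would train LightGBM here).
    f1 = 0.93 - abs(params['learning_rate'] - 0.05) - abs(params['num_leaves'] - 60) / 1000
    # Same return shape as lossFun: the optimizer minimizes 'loss' = 1 - F1.
    return {'loss': 1 - f1, 'params': params, 'status': 'ok'}

def grid_search(objective, space):
    # Exhaustive search (a stand-in for TPE): evaluate every combination
    # of the listed values and keep the trial with the smallest loss.
    names = list(space)
    trials = [objective(dict(zip(names, values)))
              for values in itertools.product(*(space[n] for n in names))]
    return min(trials, key=lambda t: t['loss'])

space = {'learning_rate': [0.01, 0.05, 0.1], 'num_leaves': [31, 60, 127]}
best = grid_search(toy_objective, space)
print(best['params'], round(1 - best['loss'], 4))
```

TPE differs from this sketch in how it picks the next candidate: it models which regions of the space gave low losses so far, instead of trying every combination. The loss convention is the same.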

5-fold cross-validation:

#5-fold cross-validation
import numpy as np
import pandas as pd
import lightgbm as lgb
from sklearn.metrics import f1_score, classification_report
from sklearn.model_selection import KFold, StratifiedKFold

def f1_score_eval(preds, valid_df):
    # custom LightGBM metric: macro F1 of the predicted class labels
    labels = valid_df.get_label()
    preds = np.asarray(preds)
    if preds.ndim == 1:
        # LightGBM < 4.0 passes multiclass scores flattened class by class
        preds = preds.reshape(3, -1).T
    preds = np.argmax(preds, axis = 1)
    scores = f1_score(y_true = labels, y_pred = preds, average = 'macro')
    return 'f1_score', scores, True   # True: a higher F1 is better

def sub_on_line_lgb(train_, test_, pred, label, cate_cols, split, is_shuffle = True, use_cart = False, get_prob = False):
    n_class = 3
    train_pred = np.zeros((train_.shape[0], n_class))   # out-of-fold probabilities
    test_pred = np.zeros((test_.shape[0], n_class))     # test probabilities averaged over the folds
    n_splits = 5
    assert split in ['kf', 'skf'], "'{}' is not a supported split type, use 'kf' or 'skf'".format(split)
    if split == 'kf':
        folds = KFold(n_splits = n_splits, shuffle = True, random_state = 1024)
        kf_way = folds.split(train_[pred])
    else:
        # stratified: every fold keeps roughly the same class proportions
        folds = StratifiedKFold(n_splits = n_splits, shuffle = True, random_state = 1024)
        kf_way = folds.split(train_[pred], train_[label])
    print('Use {} features ...'.format(len(pred)))
    # parameters set to the values found by the Bayesian optimization above
    params = {'learning_rate': 0.05,
              'boosting_type': 'gbdt',
              'objective': 'multiclass',
              'metric': 'None',
              'num_leaves': 60,
              'feature_fraction': 0.86,
              'bagging_fraction': 0.73,
              'bagging_freq': 5,
              'seed': 1,
              'bagging_seed': 1,
              'feature_fraction_seed': 7,
              'min_data_in_leaf': 15,
              'num_class': n_class,
              'nthread': 8,
              'verbose': -1,
              'num_boost_round': 1100,
              'max_depth': 7}
    for n_fold, (train_idx, valid_idx) in enumerate(kf_way, start = 1):
        print('the {} training start ...'.format(n_fold))
        train_x, train_y = train_[pred].iloc[train_idx], train_[label].iloc[train_idx]
        valid_x, valid_y = train_[pred].iloc[valid_idx], train_[label].iloc[valid_idx]
        if use_cart:
            # treat the columns listed in cate_cols as categorical features
            dtrain = lgb.Dataset(train_x, label = train_y, categorical_feature = cate_cols)
            dvalid = lgb.Dataset(valid_x, label = valid_y, categorical_feature = cate_cols)
        else:
            dtrain = lgb.Dataset(train_x, label = train_y)
            dvalid = lgb.Dataset(valid_x, label = valid_y)
        # callbacks replace early_stopping_rounds / verbose_eval, which LightGBM 4.0 removed from lgb.train
        clf = lgb.train(params = params, train_set = dtrain, valid_sets = [dvalid], feval = f1_score_eval,
                        callbacks = [lgb.early_stopping(100), lgb.log_evaluation(100)])
        train_pred[valid_idx] = clf.predict(valid_x, num_iteration = clf.best_iteration)
        test_pred += clf.predict(test_[pred], num_iteration = clf.best_iteration) / folds.n_splits
    # report on the out-of-fold predictions once every fold has been filled in
    print(classification_report(train_[label], np.argmax(train_pred, axis = 1), digits = 4))
    if get_prob:
        sub_probs = ['qyxs_prob_{}'.format(q) for q in ['围网', '刺网', '拖网']]   # purse seine, gillnet, trawl
        prob_df = pd.DataFrame(test_pred, columns = sub_probs)
        prob_df['ID'] = test_['ID'].values
        return prob_df
    else:
        sub = test_[['ID']].copy()   # work on a copy, not on a slice of the caller's DataFrame
        sub['label'] = np.argmax(test_pred, axis = 1)
        return sub
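f1_score_eval turns LightGBM's multiclass scores into labels before computing macro F1, and the reshape is only correct if it matches the layout LightGBM passes in. In LightGBM releases before 4.0, a custom feval receives a flat array grouped by class first (preds[j * num_data + i] is row i's score for class j); from 4.0 on it receives an (n_samples, n_classes) array. A small numpy sketch with made-up scores; the helper names are illustrative and not part of the original code:

```python
import numpy as np

def labels_from_class_major(flat_preds, n_class):
    # Class-major flat layout: scores of all rows for class 0, then class 1, ...
    # reshape(n_class, -1) gives one row per class and one column per sample,
    # so argmax over axis 0 picks each sample's highest-scoring class.
    return np.argmax(flat_preds.reshape(n_class, -1), axis=0)

def labels_from_2d(preds_2d):
    # 2-D layout: one row per sample, one column per class.
    return np.argmax(preds_2d, axis=1)

# Made-up scores for 4 samples (rows) and 3 classes (columns).
scores = np.array([[0.7, 0.2, 0.1],
                   [0.1, 0.8, 0.1],
                   [0.2, 0.3, 0.5],
                   [0.6, 0.3, 0.1]])
class_major = scores.T.ravel()   # flattened class by class
row_major = scores.ravel()       # flattened sample by sample

print(labels_from_class_major(class_major, 3))   # [0 1 2 0]
print(labels_from_2d(scores))                    # [0 1 2 0]
print(labels_from_class_major(row_major, 3))     # [1 2 2 1], wrong layout assumed
```

The last line shows what goes wrong when the reshape does not match the layout: the code still runs, but the labels, and therefore the reported F1, are silently wrong.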

use_train = all_df[all_df['label'] != -1]
use_test = all_df[all_df['label'] == -1].copy()   # rows labelled -1 form the test set
# use_feats = [c for c in use_train.columns if c not in ['ID', 'label']]
use_feats = model_feature
sub = sub_on_line_lgb(use_train, use_test, use_feats, 'label', [], 'kf', is_shuffle = True, use_cart = False, get_prob = False)

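The output is as follows. In this run, classification_report was printed inside the fold loop on the whole training set. Rows belonging to folds not yet predicted still held all-zero probabilities, and np.argmax assigns such rows to class 0. That is why class 0 shows high recall but low precision in the early reports (recall 0.9772, precision 0.2663 after fold 1) and why accuracy climbs fold by fold. Only the report after fold 5 reflects the real out-of-fold performance (macro F1 0.9238), while the per-fold validation F1 at the best iteration stays between about 0.920 and 0.931. The SettingWithCopyWarning at the end is raised because a column is assigned to test_, which is a slice of all_df; taking a .copy() first avoids it.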

Use 200 features ...
the 1 training start ...
Training until validation scores don't improve for 100 rounds
[100]	valid_0's f1_score: 0.894256
[200]	valid_0's f1_score: 0.909942
[300]	valid_0's f1_score: 0.913423
[400]	valid_0's f1_score: 0.917897
[500]	valid_0's f1_score: 0.920616
Early stopping, best iteration is:
[456]	valid_0's f1_score: 0.920717
              precision    recall  f1-score   support

           0     0.2663    0.9772    0.4185      1621
           1     0.9672    0.1739    0.2948      1018
           2     0.9539    0.1899    0.3167      4361

    accuracy                         0.3699      7000
   macro avg     0.7291    0.4470    0.3433      7000
weighted avg     0.7966    0.3699    0.3371      7000

the 2 training start ...
Training until validation scores don't improve for 100 rounds
[100]	valid_0's f1_score: 0.918357
[200]	valid_0's f1_score: 0.916436
Early stopping, best iteration is:
[140]	valid_0's f1_score: 0.92449
              precision    recall  f1-score   support

           0     0.3169    0.9562    0.4760      1621
           1     0.9531    0.3595    0.5221      1018
           2     0.9548    0.3777    0.5412      4361

    accuracy                         0.5090      7000
   macro avg     0.7416    0.5645    0.5131      7000
weighted avg     0.8068    0.5090    0.5234      7000

the 3 training start ...
Training until validation scores don't improve for 100 rounds
[100]	valid_0's f1_score: 0.915242
[200]	valid_0's f1_score: 0.927189
[300]	valid_0's f1_score: 0.930614
Early stopping, best iteration is:
[238]	valid_0's f1_score: 0.930614
              precision    recall  f1-score   support

           0     0.3946    0.9389    0.5557      1621
           1     0.9571    0.5255    0.6785      1018
           2     0.9574    0.5673    0.7125      4361

    accuracy                         0.6473      7000
   macro avg     0.7697    0.6773    0.6489      7000
weighted avg     0.8270    0.6473    0.6712      7000

the 4 training start ...
Training until validation scores don't improve for 100 rounds
[100]	valid_0's f1_score: 0.901683
[200]	valid_0's f1_score: 0.912985
[300]	valid_0's f1_score: 0.916988
[400]	valid_0's f1_score: 0.92147
[500]	valid_0's f1_score: 0.921353
Early stopping, best iteration is:
[411]	valid_0's f1_score: 0.922153
              precision    recall  f1-score   support

           0     0.5392    0.9173    0.6792      1621
           1     0.9589    0.7112    0.8167      1018
           2     0.9555    0.7640    0.8491      4361

    accuracy                         0.7919      7000
   macro avg     0.8179    0.7975    0.7817      7000
weighted avg     0.8596    0.7919    0.8051      7000

the 5 training start ...
Training until validation scores don't improve for 100 rounds
[100]	valid_0's f1_score: 0.900975
[200]	valid_0's f1_score: 0.908373
[300]	valid_0's f1_score: 0.91384
[400]	valid_0's f1_score: 0.917567
Early stopping, best iteration is:
[369]	valid_0's f1_score: 0.919843
              precision    recall  f1-score   support

           0     0.8726    0.9001    0.8861      1621
           1     0.9569    0.8949    0.9249      1018
           2     0.9586    0.9619    0.9603      4361

    accuracy                         0.9379      7000
   macro avg     0.9294    0.9190    0.9238      7000
weighted avg     0.9385    0.9379    0.9380      7000

<ipython-input-4-ad6d0ae907b5>:58: SettingWithCopyWarning: 
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  test_['label'] = np.argmax(test_pred, axis=1)
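A minimal numpy sketch of why a report computed on a partially filled out-of-fold matrix is misleading: np.argmax on an all-zero row returns 0, so rows that no fold has predicted yet are counted as class 0. The model is a hypothetical stand-in that always outputs class 2 with probability 0.8, and kfold_valid_indices and cumulative_class_counts are illustrative helpers, not part of the original code:

```python
import numpy as np

def kfold_valid_indices(n_samples, n_splits, seed=0):
    # Shuffle the row indices once and cut them into n_splits validation folds.
    perm = np.random.default_rng(seed).permutation(n_samples)
    return np.array_split(perm, n_splits)

def cumulative_class_counts(n_samples=10, n_class=3, n_splits=5):
    oof = np.zeros((n_samples, n_class))   # plays the role of train_pred
    history = []
    for valid_idx in kfold_valid_indices(n_samples, n_splits):
        # Hypothetical model output: class 2 with probability 0.8 for every row.
        oof[valid_idx] = [0.1, 0.1, 0.8]
        # Rows of folds not predicted yet are still all zeros, and
        # np.argmax of an all-zero row is 0, so they count as class 0.
        labels = np.argmax(oof, axis=1)
        history.append(np.bincount(labels, minlength=n_class).tolist())
    return history

for k, counts in enumerate(cumulative_class_counts(), start=1):
    print('after fold', k, 'predicted class counts:', counts)
```

The spurious class-0 count shrinks by one fold's worth of rows per iteration and reaches zero only after the last fold. Evaluating once after the loop, or only on valid_idx inside it, avoids the effect.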

For an introduction to the split() function, see: https://blog.csdn.net/kancy110/article/details/74910185?utm_medium=distribute.pc_relevant_t0.none-task-blog-2%7Edefault%7EBlogCommendFromMachineLearnPai2%7Edefault-1.control&dist_request_id=&depth_1-utm_source=distribute.pc_relevant_t0.none-task-blog-2%7Edefault%7EBlogCommendFromMachineLearnPai2%7Edefault-1.control
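KFold.split(X) and StratifiedKFold.split(X, y) both yield (train_idx, valid_idx) index pairs, which is what the loop in sub_on_line_lgb iterates over. The stratified version also keeps each class's share roughly equal across folds, which matters for an imbalanced target like the 1621 / 1018 / 4361 split seen in the reports above. Below is a simplified numpy illustration of the stratification idea: shuffle each class's indices and deal them round-robin to the folds. It is an assumption-based sketch, not sklearn's exact algorithm, and the function names are hypothetical:

```python
import numpy as np

def stratified_fold_ids(y, n_splits, seed=0):
    # Give every sample a fold id so that each class is spread evenly:
    # shuffle the indices of one class, then deal them round-robin to the folds.
    rng = np.random.default_rng(seed)
    fold_of = np.empty(len(y), dtype=int)
    for cls in np.unique(y):
        idx = rng.permutation(np.flatnonzero(y == cls))
        fold_of[idx] = np.arange(len(idx)) % n_splits
    return fold_of

def stratified_split(y, n_splits, seed=0):
    # Yield (train_idx, valid_idx) pairs, the same kind of output that
    # KFold.split / StratifiedKFold.split produce for the training loop.
    fold_of = stratified_fold_ids(y, n_splits, seed)
    for k in range(n_splits):
        yield np.flatnonzero(fold_of != k), np.flatnonzero(fold_of == k)

# Imbalanced toy labels: 10 samples of class 0, 5 of class 1, 25 of class 2.
y = np.array([0] * 10 + [1] * 5 + [2] * 25)
for train_idx, valid_idx in stratified_split(y, n_splits=5):
    # every validation fold holds 2, 1 and 5 samples of the three classes
    print(len(train_idx), len(valid_idx), np.bincount(y[valid_idx], minlength=3))
```

When a class count is not divisible by the number of folds, this simple round-robin always gives the extra samples to the lowest-numbered folds. sklearn's implementation balances fold sizes more carefully.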
