from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.model_selection import GridSearchCV
from sklearn.metrics import mean_squared_error
import lightgbm as lgb
读取数据
# Fetch the bundled Iris dataset, then pull the feature matrix (`data`)
# and the label vector (`target`) out of the returned container.
iris = load_iris()
data, target = iris.data, iris.target
# Bare expression: echoes the feature matrix when run as a notebook cell.
data
array([[5.1, 3.5, 1.4, 0.2], [4.9, 3. , 1.4, 0.2], [4.7, 3.2, 1.3, 0.2], [4.6, 3.1, 1.5, 0.2], [5. , 3.6, 1.4, 0.2], [5.4, 3.9, 1.7, 0.4], [4.6, 3.4, 1.4, 0.3], [5. , 3.4, 1.5, 0.2], [4.4, 2.9, 1.4, 0.2], [4.9, 3.1, 1.5, 0.1], [5.4, 3.7, 1.5, 0.2], [4.8, 3.4, 1.6, 0.2], [4.8, 3. , 1.4, 0.1], [4.3, 3. , 1.1, 0.1], [5.8, 4. , 1.2, 0.2], [5.7, 4.4, 1.5, 0.4], [5.4, 3.9, 1.3, 0.4], [5.1, 3.5, 1.4, 0.3], [5.7, 3.8, 1.7, 0.3], [5.1, 3.8, 1.5, 0.3], [5.4, 3.4, 1.7, 0.2], [5.1, 3.7, 1.5, 0.4], [4.6, 3.6, 1. , 0.2], [5.1, 3.3, 1.7, 0.5], [4.8, 3.4, 1.9, 0.2], [5. , 3. , 1.6, 0.2], [5. , 3.4, 1.6, 0.4], [5.2, 3.5, 1.5, 0.2], [5.2, 3.4, 1.4, 0.2], [4.7, 3.2, 1.6, 0.2], [4.8, 3.1, 1.6, 0.2], [5.4, 3.4, 1.5, 0.4], [5.2, 4.1, 1.5, 0.1], [5.5, 4.2, 1.4, 0.2], [4.9, 3.1, 1.5, 0.2], [5. , 3.2, 1.2, 0.2], [5.5, 3.5, 1.3, 0.2], [4.9, 3.6, 1.4, 0.1], [4.4, 3. , 1.3, 0.2], [5.1, 3.4, 1.5, 0.2], [5. , 3.5, 1.3, 0.3], [4.5, 2.3, 1.3, 0.3], [4.4, 3.2, 1.3, 0.2], [5. , 3.5, 1.6, 0.6], [5.1, 3.8, 1.9, 0.4], [4.8, 3. , 1.4, 0.3], [5.1, 3.8, 1.6, 0.2], [4.6, 3.2, 1.4, 0.2], [5.3, 3.7, 1.5, 0.2], [5. , 3.3, 1.4, 0.2], [7. , 3.2, 4.7, 1.4], [6.4, 3.2, 4.5, 1.5], [6.9, 3.1, 4.9, 1.5], [5.5, 2.3, 4. , 1.3], [6.5, 2.8, 4.6, 1.5], [5.7, 2.8, 4.5, 1.3], [6.3, 3.3, 4.7, 1.6], [4.9, 2.4, 3.3, 1. ], [6.6, 2.9, 4.6, 1.3], [5.2, 2.7, 3.9, 1.4], [5. , 2. , 3.5, 1. ], [5.9, 3. , 4.2, 1.5], [6. , 2.2, 4. , 1. ], [6.1, 2.9, 4.7, 1.4], [5.6, 2.9, 3.6, 1.3], [6.7, 3.1, 4.4, 1.4], [5.6, 3. , 4.5, 1.5], [5.8, 2.7, 4.1, 1. ], [6.2, 2.2, 4.5, 1.5], [5.6, 2.5, 3.9, 1.1], [5.9, 3.2, 4.8, 1.8], [6.1, 2.8, 4. , 1.3], [6.3, 2.5, 4.9, 1.5], [6.1, 2.8, 4.7, 1.2], [6.4, 2.9, 4.3, 1.3], [6.6, 3. , 4.4, 1.4], [6.8, 2.8, 4.8, 1.4], [6.7, 3. , 5. , 1.7], [6. , 2.9, 4.5, 1.5], [5.7, 2.6, 3.5, 1. ], [5.5, 2.4, 3.8, 1.1], [5.5, 2.4, 3.7, 1. ], [5.8, 2.7, 3.9, 1.2], [6. , 2.7, 5.1, 1.6], [5.4, 3. , 4.5, 1.5], [6. , 3.4, 4.5, 1.6], [6.7, 3.1, 4.7, 1.5], [6.3, 2.3, 4.4, 1.3], [5.6, 3. , 4.1, 1.3], [5.5, 2.5, 4. 
, 1.3], [5.5, 2.6, 4.4, 1.2], [6.1, 3. , 4.6, 1.4], [5.8, 2.6, 4. , 1.2], [5. , 2.3, 3.3, 1. ], [5.6, 2.7, 4.2, 1.3], [5.7, 3. , 4.2, 1.2], [5.7, 2.9, 4.2, 1.3], [6.2, 2.9, 4.3, 1.3], [5.1, 2.5, 3. , 1.1], [5.7, 2.8, 4.1, 1.3], [6.3, 3.3, 6. , 2.5], [5.8, 2.7, 5.1, 1.9], [7.1, 3. , 5.9, 2.1], [6.3, 2.9, 5.6, 1.8], [6.5, 3. , 5.8, 2.2], [7.6, 3. , 6.6, 2.1], [4.9, 2.5, 4.5, 1.7], [7.3, 2.9, 6.3, 1.8], [6.7, 2.5, 5.8, 1.8], [7.2, 3.6, 6.1, 2.5], [6.5, 3.2, 5.1, 2. ], [6.4, 2.7, 5.3, 1.9], [6.8, 3. , 5.5, 2.1], [5.7, 2.5, 5. , 2. ], [5.8, 2.8, 5.1, 2.4], [6.4, 3.2, 5.3, 2.3], [6.5, 3. , 5.5, 1.8], [7.7, 3.8, 6.7, 2.2], [7.7, 2.6, 6.9, 2.3], [6. , 2.2, 5. , 1.5], [6.9, 3.2, 5.7, 2.3], [5.6, 2.8, 4.9, 2. ], [7.7, 2.8, 6.7, 2. ], [6.3, 2.7, 4.9, 1.8], [6.7, 3.3, 5.7, 2.1], [7.2, 3.2, 6. , 1.8], [6.2, 2.8, 4.8, 1.8], [6.1, 3. , 4.9, 1.8], [6.4, 2.8, 5.6, 2.1], [7.2, 3. , 5.8, 1.6], [7.4, 2.8, 6.1, 1.9], [7.9, 3.8, 6.4, 2. ], [6.4, 2.8, 5.6, 2.2], [6.3, 2.8, 5.1, 1.5], [6.1, 2.6, 5.6, 1.4], [7.7, 3. , 6.1, 2.3], [6.3, 3.4, 5.6, 2.4], [6.4, 3.1, 5.5, 1.8], [6. , 3. , 4.8, 1.8], [6.9, 3.1, 5.4, 2.1], [6.7, 3.1, 5.6, 2.4], [6.9, 3.1, 5.1, 2.3], [5.8, 2.7, 5.1, 1.9], [6.8, 3.2, 5.9, 2.3], [6.7, 3.3, 5.7, 2.5], [6.7, 3. , 5.2, 2.3], [6.3, 2.5, 5. , 1.9], [6.5, 3. , 5.2, 2. ], [6.2, 3.4, 5.4, 2.3], [5.9, 3. , 5.1, 1.8]])
# Bare expression: echoes the label vector when run as a notebook cell.
target
array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2])
数据基本处理
# Hold out 20% of the samples for evaluation. The split is seeded so the
# train/test partition -- and every score computed from it -- is identical
# on each run instead of changing with every execution.
X_train, X_test, y_train, y_test = train_test_split(
    data, target, test_size=0.2, random_state=42
)
模型训练
模型基本训练
# Baseline model: 20 boosting rounds at a learning rate of 0.05, evaluated
# on the held-out split with L1 as the requested evaluation metric.
gbm = lgb.LGBMRegressor(objective="regression", learning_rate=0.05, n_estimators=20)
# Early stopping is passed as a callback: LightGBM 4.0 removed the
# `early_stopping_rounds` keyword from fit(), while the callback form is
# accepted by both older and current releases.
gbm.fit(
    X_train,
    y_train,
    eval_set=[(X_test, y_test)],
    eval_metric="l1",
    callbacks=[lgb.early_stopping(stopping_rounds=3)],
)
# R^2 of the fitted regressor on the held-out split (echoed in a notebook).
gbm.score(X_test, y_test)
[1] valid_0's l1: 0.653531 valid_0's l2: 0.626219 Training until validation scores don't improve for 3 rounds [2] valid_0's l1: 0.626209 valid_0's l2: 0.57348 [3] valid_0's l1: 0.60108 valid_0's l2: 0.525437 [4] valid_0's l1: 0.577988 valid_0's l2: 0.482521 [5] valid_0's l1: 0.555301 valid_0's l2: 0.443297 [6] valid_0's l1: 0.534806 valid_0's l2: 0.408881 [7] valid_0's l1: 0.510834 valid_0's l2: 0.372852 [8] valid_0's l1: 0.491373 valid_0's l2: 0.344015 [9] valid_0's l1: 0.469678 valid_0's l2: 0.314384 [10] valid_0's l1: 0.451908 valid_0's l2: 0.290418 [11] valid_0's l1: 0.433932 valid_0's l2: 0.268274 [12] valid_0's l1: 0.414266 valid_0's l2: 0.245211 [13] valid_0's l1: 0.398027 valid_0's l2: 0.227095 [14] valid_0's l1: 0.380293 valid_0's l2: 0.208076 [15] valid_0's l1: 0.365621 valid_0's l2: 0.193252 [16] valid_0's l1: 0.34957 valid_0's l2: 0.177498 [17] valid_0's l1: 0.336313 valid_0's l2: 0.16537 [18] valid_0's l1: 0.321785 valid_0's l2: 0.152308 [19] valid_0's l1: 0.310088 valid_0's l2: 0.142386 [20] valid_0's l1: 0.298266 valid_0's l2: 0.131543 Did not meet early stopping. Best iteration is: [20] valid_0's l1: 0.298266 valid_0's l2: 0.131543 0.7578964818630016
通过网格搜索进行训练
# Exhaustive 5-fold cross-validated search over learning rate and the
# number of boosting rounds.
estimators = lgb.LGBMRegressor(num_leaves=31)
param_grid = {
    "learning_rate": [0.01, 0.1, 1],
    # Must be spelled exactly "n_estimators": a misspelled key is not
    # rejected, it is simply never applied to the real setting, so every
    # candidate would silently train with the default number of trees.
    "n_estimators": [20, 40, 60, 80],
}
gbm = GridSearchCV(estimators, param_grid, cv=5)
gbm.fit(X_train, y_train)
GridSearchCV(cv=5, error_score=nan, estimator=LGBMRegressor(boosting_type='gbdt', class_weight=None, colsample_bytree=1.0, importance_type='split', learning_rate=0.1, max_depth=-1, min_child_samples=20, min_child_weight=0.001, min_split_gain=0.0, n_estimators=100, n_jobs=-1, num_leaves=31, objective=None, random_state=None, reg_alpha=0.0, reg_lambda=0.0, silent=True, subsample=1.0, subsample_for_bin=200000, subsample_freq=0), iid='deprecated', n_jobs=None, param_grid={'learning_rate': [0.01, 0.1, 1], 'n_estmators': [20, 40, 60, 80]}, pre_dispatch='2*n_jobs', refit=True, return_train_score=False, scoring=None, verbose=0)
# Bare expression: shows the parameter combination GridSearchCV selected.
gbm.best_params_
{'learning_rate': 0.1, 'n_estmators': 20}
# Retrain on the full training split with the learning rate reported by
# the grid search (0.1) and 20 boosting rounds, then evaluate on the
# held-out split.
gbm = lgb.LGBMRegressor(objective="regression", learning_rate=0.1, n_estimators=20)
# Early stopping is passed as a callback: LightGBM 4.0 removed the
# `early_stopping_rounds` keyword from fit(), while the callback form is
# accepted by both older and current releases.
gbm.fit(
    X_train,
    y_train,
    eval_set=[(X_test, y_test)],
    eval_metric="l1",
    callbacks=[lgb.early_stopping(stopping_rounds=3)],
)
# R^2 of the fitted regressor on the held-out split (echoed in a notebook).
gbm.score(X_test, y_test)
[1] valid_0's l1: 0.625261 valid_0's l2: 0.571453 Training until validation scores don't improve for 3 rounds [2] valid_0's l1: 0.574385 valid_0's l2: 0.477181 [3] valid_0's l1: 0.531459 valid_0's l2: 0.403427 [4] valid_0's l1: 0.483888 valid_0's l2: 0.33428 [5] valid_0's l1: 0.447306 valid_0's l2: 0.284716 [6] valid_0's l1: 0.413883 valid_0's l2: 0.243537 [7] valid_0's l1: 0.377047 valid_0's l2: 0.203656 [8] valid_0's l1: 0.348048 valid_0's l2: 0.175576 [9] valid_0's l1: 0.318049 valid_0's l2: 0.148479 [10] valid_0's l1: 0.29463 valid_0's l2: 0.129983 [11] valid_0's l1: 0.27226 valid_0's l2: 0.111468 [12] valid_0's l1: 0.2489 valid_0's l2: 0.0960426 [13] valid_0's l1: 0.230634 valid_0's l2: 0.0833998 [14] valid_0's l1: 0.216687 valid_0's l2: 0.0759234 [15] valid_0's l1: 0.1993 valid_0's l2: 0.0670385 [16] valid_0's l1: 0.188099 valid_0's l2: 0.0622206 [17] valid_0's l1: 0.178022 valid_0's l2: 0.058299 [18] valid_0's l1: 0.168954 valid_0's l2: 0.0551119 [19] valid_0's l1: 0.158303 valid_0's l2: 0.0505529 [20] valid_0's l1: 0.149623 valid_0's l2: 0.0466022 Did not meet early stopping. Best iteration is: [20] valid_0's l1: 0.149623 valid_0's l2: 0.0466022 0.9142290887795029