在用XGBClassifier做多分类问题模型存取时,采用save_model与load_model函数发现并不是很好用,因此通过pickle进行模型的存取工作,在此记录,以备后用。
import pickle
from xgboost import XGBClassifier
#train
model_xg = XGBClassifier(
n_estimators=20,
learning_rate=0.1,
max_depth=8,
subsample=0.8,
early_stopping_rounds = 50,
objective='multi:softmax',
eval_metric = 'mlogloss')
model_xg.fit(x_train, y_train,verbose=True)
# save
pickle.dump(model_xg, open("xgb.pkl", "wb"))
# load
xgb_model_loaded = pickle.load(open("xgb.pkl", "rb"))
# test
xgb_model_loaded.predict(test)