XGBoost 的 sklearn 接口使用示例分析
1.多分类
加载数据
# Announce which example is running.
print("Iris: multiclass classification")
# Load the dataset; its 'data' and 'target' entries are read further down.
# NOTE(review): `load_iris` is imported outside this snippet, presumably
# from sklearn.datasets -- confirm.
iris = load_iris()
划分训练数据和测试数据
# Pull the label vector and the feature matrix out of the loaded dataset.
y = iris['target']
X = iris['data']
# Two shuffled folds. NOTE(review): `rng` is defined outside this snippet,
# presumably a seeded random state for reproducible splits -- confirm.
kf = KFold(n_splits=2, shuffle=True, random_state=rng)
for train_index, test_index in kf.split(X):
    # Model training: fit a fresh classifier on this fold's training rows.
    xgb_model = xgb.XGBClassifier().fit(X[train_index], y[train_index])
    # Predict on the held-out rows and compare against the true labels.
    predictions = xgb_model.predict(X[test_index])
    actuals = y[test_index]
    print(confusion_matrix(actuals, predictions))
模型保存
print("Pickling sklearn API models")
# NOTE(review): `clf` is not defined in the visible snippet (the model fitted
# above is `xgb_model`); presumably it is a fitted estimator from another
# part of the example -- confirm.
# Pickle data is bytes, so the file is opened in binary mode. The `with`
# block closes (and therefore flushes) the handle before the file is re-read.
with open("best_boston.pkl", "wb") as model_file:
    pickle.dump(clf, model_file)
# NOTE: only unpickle files you trust; pickle.load can run arbitrary code.
with open("best_boston.pkl", "rb") as model_file:
    clf2 = pickle.load(model_file)
# Sanity check: the restored model must reproduce the original predictions.
print(np.allclose(clf.predict(X), clf2.predict(X)))