导入库
from sklearn import datasets
from sklearn.model_selection import train_test_split
import xgboost as xgb
import numpy as np
import pandas as pd
from sklearn.metrics import precision_score, recall_score
import warnings
warnings.filterwarnings("ignore")
导入数据集并整理为 DataFrame
# Load the iris dataset and wrap its features and targets in labelled DataFrames.
iris = datasets.load_iris()
data, label = iris.data, iris.target

# Feature table: name the four measurement columns at construction time.
data1 = pd.DataFrame(data, columns=['sepal_l', 'sepal_w', 'petal_l', 'petal_w'])
print(data1.head())

# Target table: a single 'label' column holding the values of iris.target.
label1 = pd.DataFrame(label, columns=['label'])
print(label1.head())
sepal_l sepal_w petal_l petal_w
0 5.1 3.5 1.4 0.2
1 4.9 3.0 1.4 0.2
2 4.7 3.2 1.3 0.2
3 4.6 3.1 1.5 0.2
4 5.0 3.6 1.4 0.2
label
0 0
1 0
2 0
3 0
4 0
划分数据集
# Hold out 30% of the rows for testing; the fixed random_state keeps the
# split reproducible between runs.
# The labels are flattened to 1-D: label1.values is an (n, 1) column vector,
# while scikit-learn documents y as 1-D, and the global
# warnings.filterwarnings("ignore") would hide any shape warning.
train_x, test_x, train_y, test_y = train_test_split(
    data1.values,
    label1.values.ravel(),
    test_size=0.3,
    random_state=42,
)
print("训练集长度:", len(train_x))
print("测试集长度:", len(test_x))
训练集长度: 105
测试集长度: 45
模型训练与评估
# Train a multi-class booster through the native XGBoost API, then score it
# on the held-out split with macro-averaged precision and recall.
train_data = xgb.DMatrix(train_x, label=train_y)
test_data = xgb.DMatrix(test_x, label=test_y)
xgb_params = {
    'eta': 0.3,
    # 'silent' is deprecated in XGBoost; 'verbosity': 0 is the documented
    # replacement for suppressing training messages.
    'verbosity': 0,
    'objective': 'multi:softprob',
    'num_class': 3,
    'max_depth': 3,
}
num_round = 20
model = xgb.train(xgb_params, train_data, num_round)

# The first five rows printed below each hold one probability per class.
test_pre = model.predict(test_data)
print(test_pre[:5])

# Pick the most probable class per row in one vectorised call instead of a
# Python-level loop over rows.
test_pre_1 = np.argmax(test_pre, axis=1)
print("test的预测结果:",test_pre_1)
print('验证集精准率:',precision_score(test_y, test_pre_1, average='macro'))
print('验证集召回率:',recall_score(test_y, test_pre_1, average='macro'))
[[0.00650657 0.96226174 0.03123167]
[0.970643 0.02533228 0.00402478]
[0.0033913 0.00692109 0.9896876 ]
[0.00654362 0.9677424 0.02571394]
[0.00615641 0.9104776 0.083366 ]]
test的预测结果: [1 0 2 1 1 0 1 2 1 1 2 0 0 0 0 1 2 1 1 2 0 2 0 2 2 2 2 2 0 0 0 0 1 0 0 2 1
0 0 0 2 1 1 0 0]
验证集精准率: 1.0
验证集召回率: 1.0
以 Sklearn 接口形式使用 XGBoost
from xgboost import XGBClassifier

# Fit the scikit-learn style wrapper on the same split and report the same
# macro-averaged metrics as the native-API run.
model = XGBClassifier(
    learning_rate=0.01,
    n_estimators=3000,
    max_depth=4,
    # The target has three classes (the recorded predictions contain 0, 1
    # and 2), so the objective is declared as multi-class rather than
    # 'binary:logistic'.
    objective='multi:softprob',
    # 'random_state' is the documented name for the seed in the sklearn
    # wrapper; 'seed' is the deprecated spelling.
    random_state=27,
)
model.fit(train_x, train_y)

# predict() on the wrapper returns class labels directly, so no argmax step
# is needed here.
test_pre2 = model.predict(test_x)
print(test_pre2)
print('验证集精准率:',precision_score(test_y, test_pre2, average='macro'))
print('验证集召回率:',recall_score(test_y, test_pre2, average='macro'))
[1 0 2 1 1 0 1 2 1 1 2 0 0 0 0 1 2 1 1 2 0 2 0 2 2 2 2 2 0 0 0 0 1 0 0 2 1
0 0 0 2 1 1 0 0]
验证集精准率: 1.0
验证集召回率: 1.0