【XGBoost 多分类】XGBoost解决多分类问题

下面将以一个例子来讲解 XGBoost 解决多分类问题。

1、下载数据集,数据集我们采用小麦种子数据集,该数据集有3类,已知小麦种子包含 7个特征,分别为面积,周长,紧凑度,仔粒长度,仔粒宽度,不对称系数,仔粒腹沟长度,小麦类别为1,2,3

linux --下载数据集:

wget https://archive.ics.uci.edu/ml/machine-learning-databases/00236/seeds_dataset.txt

window–下载数据集:
将地址复制到浏览器即可下载

https://archive.ics.uci.edu/ml/machine-learning-databases/00236/seeds_dataset.txt

文件名叫 seeds_dataset.txt 数据类似如下:

15.26	14.84	0.871	5.763	3.312	2.221	5.22	1
14.88	14.57	0.8811	5.554	3.333	1.018	4.956	1
14.29	14.09	0.905	5.291	3.337	2.699	4.825	1
13.84	13.94	0.8955	5.324	3.379	2.259	4.805	1
16.14	14.99	0.9034	5.658	3.562	1.355	5.175	1
14.38	14.21	0.8951	5.386	3.312	2.462	4.956	1
14.69	14.49	0.8799	5.563	3.259	3.586	5.219	1
14.11	14.1	0.8911	5.42	3.302	2.7		5		1
16.63	15.46	0.8747	6.053	3.465	2.04	5.877	1
16.44	15.25	0.888	5.884	3.505	1.969	5.533	1
15.26	14.85	0.8696	5.714	3.242	4.543	5.314	1
14.03	14.16	0.8796	5.438	3.201	1.717	5.001	1
13.89	14.02	0.888	5.439	3.199	3.986	4.738	1
.....
# -*- coding: utf-8 -*-

import pandas as pd
import xgboost as xgb
import numpy as np
import warnings
warnings.filterwarnings('ignore')
from sklearn.model_selection import train_test_split

data_path='./datasets/seeds_dataset.txt'
data=pd.read_csv(data_path,header=None,sep='\s+',converters={7:lambda x:int(x)-1})
data.rename(columns={7:'lable'},inplace=True)
print(data)

# # # 生产一个随机数并选择小于0.8的数据
# mask=np.random.rand(len(data))<0.8
# train=data[mask]
# test=data[~mask]
#
# # 生产DMatrix
# xgb_train=xgb.DMatrix(train.iloc[:,:6],label=train.lable)
# xgb_test=xgb.DMatrix(test.iloc[:,:6],label=test.lable)



X=data.iloc[:,:6]
Y=data.iloc[:,7]
X_train, X_test, y_train, y_test = train_test_split(X, Y, test_size=0.25, random_state=100)

xgb_train=xgb.DMatrix(X_train,label=y_train)
xgb_test=xgb.DMatrix(X_test,label=y_test)



# 设置模型参数

params={
    'objective':'multi:softmax',
    'eta':0.1,
    'max_depth':5,
    'num_class':3
}

watchlist=[(xgb_train,'train'),(xgb_test,'test')]
# 设置训练轮次,这里设置60轮
num_round=60
bst=xgb.train(params,xgb_train,num_round,watchlist)

# 模型预测

pred=bst.predict(xgb_test)
print(pred)

#模型评估

# error_rate=np.sum(pred!=test.lable)/test.lable.shape[0]
error_rate=np.sum(pred!=y_test)/y_test.shape[0]

print('测试集错误率(softmax):{}'.format(error_rate))

accuray=1-error_rate
print('测试集准确率:%.4f' %accuray)


# 模型保存
bst.save_model("./datasets/002.model")


# 模型加载
bst=xgb.Booster()
bst.load_model("./datasets/002.model")
pred=bst.predict(xgb_test)
print(pred)

运行结果:

       0      1       2      3      4      5      6  lable
0    15.26  14.84  0.8710  5.763  3.312  2.221  5.220      0
1    14.88  14.57  0.8811  5.554  3.333  1.018  4.956      0
2    14.29  14.09  0.9050  5.291  3.337  2.699  4.825      0
3    13.84  13.94  0.8955  5.324  3.379  2.259  4.805      0
4    16.14  14.99  0.9034  5.658  3.562  1.355  5.175      0
..     ...    ...     ...    ...    ...    ...    ...    ...
205  12.19  13.20  0.8783  5.137  2.981  3.631  4.870      2
206  11.23  12.88  0.8511  5.140  2.795  4.325  5.003      2
207  13.20  13.66  0.8883  5.236  3.232  8.315  5.056      2
208  11.84  13.21  0.8521  5.175  2.836  3.598  5.044      2
209  12.30  13.34  0.8684  5.243  2.974  5.637  5.063      2

[210 rows x 8 columns]
[0]	train-merror:0.012739	test-merror:0.075472
[1]	train-merror:0.012739	test-merror:0.056604
[2]	train-merror:0.006369	test-merror:0.075472
[3]	train-merror:0.012739	test-merror:0.075472
[4]	train-merror:0.006369	test-merror:0.075472
[5]	train-merror:0	test-merror:0.075472
[6]	train-merror:0	test-merror:0.075472
[7]	train-merror:0	test-merror:0.075472
[8]	train-merror:0	test-merror:0.075472
[9]	train-merror:0	test-merror:0.075472
[10]	train-merror:0	test-merror:0.075472
[11]	train-merror:0	test-merror:0.075472
[12]	train-merror:0	test-merror:0.075472
[13]	train-merror:0	test-merror:0.075472
[14]	train-merror:0	test-merror:0.075472
[15]	train-merror:0	test-merror:0.075472
[16]	train-merror:0	test-merror:0.075472
[17]	train-merror:0	test-merror:0.075472
[18]	train-merror:0	test-merror:0.075472
[19]	train-merror:0	test-merror:0.075472
[20]	train-merror:0	test-merror:0.075472
[21]	train-merror:0	test-merror:0.075472
[22]	train-merror:0	test-merror:0.075472
[23]	train-merror:0	test-merror:0.075472
[24]	train-merror:0	test-merror:0.075472
[25]	train-merror:0	test-merror:0.075472
[26]	train-merror:0	test-merror:0.075472
[27]	train-merror:0	test-merror:0.075472
[28]	train-merror:0	test-merror:0.075472
[29]	train-merror:0	test-merror:0.075472
[30]	train-merror:0	test-merror:0.075472
[31]	train-merror:0	test-merror:0.075472
[32]	train-merror:0	test-merror:0.075472
[33]	train-merror:0	test-merror:0.075472
[34]	train-merror:0	test-merror:0.075472
[35]	train-merror:0	test-merror:0.075472
[36]	train-merror:0	test-merror:0.075472
[37]	train-merror:0	test-merror:0.075472
[38]	train-merror:0	test-merror:0.075472
[39]	train-merror:0	test-merror:0.075472
[40]	train-merror:0	test-merror:0.075472
[41]	train-merror:0	test-merror:0.075472
[42]	train-merror:0	test-merror:0.075472
[43]	train-merror:0	test-merror:0.075472
[44]	train-merror:0	test-merror:0.075472
[45]	train-merror:0	test-merror:0.075472
[46]	train-merror:0	test-merror:0.075472
[47]	train-merror:0	test-merror:0.075472
[48]	train-merror:0	test-merror:0.075472
[49]	train-merror:0	test-merror:0.075472
[50]	train-merror:0	test-merror:0.075472
[51]	train-merror:0	test-merror:0.075472
[52]	train-merror:0	test-merror:0.075472
[53]	train-merror:0	test-merror:0.075472
[54]	train-merror:0	test-merror:0.075472
[55]	train-merror:0	test-merror:0.075472
[56]	train-merror:0	test-merror:0.075472
[57]	train-merror:0	test-merror:0.075472
[58]	train-merror:0	test-merror:0.075472
[59]	train-merror:0	test-merror:0.075472
[0. 2. 2. 2. 1. 0. 2. 1. 2. 2. 1. 1. 1. 0. 1. 0. 2. 1. 0. 2. 1. 0. 1. 0.
 2. 2. 1. 2. 0. 0. 2. 0. 2. 2. 1. 2. 2. 2. 2. 1. 1. 0. 1. 1. 0. 2. 0. 2.
 2. 1. 0. 2. 2.]
测试集错误率(softmax):0.07547169811320754
测试集准确率:0.9245
[0. 2. 2. 2. 1. 0. 2. 1. 2. 2. 1. 1. 1. 0. 1. 0. 2. 1. 0. 2. 1. 0. 1. 0.
 2. 2. 1. 2. 0. 0. 2. 0. 2. 2. 1. 2. 2. 2. 2. 1. 1. 0. 1. 1. 0. 2. 0. 2.
 2. 1. 0. 2. 2.]

Process finished with exit code 0

发布了670 篇原创文章 · 获赞 834 · 访问量 200万+
展开阅读全文

没有更多推荐了,返回首页

©️2019 CSDN 皮肤主题: 编程工作室 设计师: CSDN官方博客

分享到微信朋友圈

×

扫一扫,手机浏览