# -*- coding: utf-8 -*-
__author__ = 'gerry'
'''
XGBoost案例之蘑菇是否有毒
任务:根据蘑菇的22个特征判断蘑菇是否有毒
数据介绍:
总样本数:8124
-可食用:4208,51.8%
-有毒:3916,48.2%
-训练样本:6513
-测试样本:1611
'''
#导入需要的工具包
import xgboost as xgb
from sklearn.metrics import accuracy_score
import matplotlib.pyplot as plt
#将数据从文件中读出,并为XGBoost训练准备好
my_workpath = './data/'
dtrain = xgb.DMatrix(my_workpath+'agaricus.txt.train')
dtest = xgb.DMatrix(my_workpath+'agaricus.txt.test')
'''
该数据为libsvm格式的文本数据,libsvm的文件格式(稀疏特征)
-每一行为一个样本:1 3:1 9:1 19:1 21:1 30:1
* 开头的"1"是样本的标签。3,9位特征索引,1,1为特征的值
* 在两类分类中,用1表示正样本,0表示负样本,也支持用[0,1]表示概率用来做标签,表示正样本的概率
XGBoost加载的数据对象存储在对象Dmatrix中,做了存储效率和运行速度的优化
支持三种数据接口:
* libsvm.txt格式数据文件
* 常规矩阵(numpy 2D array)
* xgboost binary buffer file
'''
#设置训练参数
# specify parameters via map
param = {
'max_depth':3,
'eta':1,
'silent':0,
'objective':'binary:logistic'
}
'''
max_depth:树的最大深度,缺省值为6,取值范围:[1,∞]
eta:为了防止过拟合,更新过程用到的收缩步长,eta通过缩减特征的权重使提升计算过程更加保守。缺省值为0.3,取值范围为[0,1]
silent:0表示打印出运行时信息,1表示以缄默方式运行,缺省值为0
objective:定义学习任务以及相应的学习目标,'binary:logistic'表示二分类的逻辑回归问题,输出为概率
'''
# 模型训练
# 设置boosting迭代计算参数
num_round = 2
bst = xgb.train(param,dtrain,num_round)
'''
与scikit-learn结合
-XGBoost提供一个wrapper类,允许模型可以和scikit-learn框架中的其他分类器或者回归器一样对待
XGBoost中分类器为XGBClassifier-模型在构造时传递
'''
#bst = xgb.XGBClassifier(max_depth=2,learning_rate=1,n_estimators=num_round,silent=True,objective='binary:logistic')
#预测(训练数据上评估 )
# 模型训练好后,可以用训练好的模型对进行预测
# XGBoost预测的输出时概率,输出值是样本为第一类的概率-->将其概率值转换为0或1
train_preds = bst.predict(dtrain)
train_predictions = [round(value) for value in train_preds]
y_train = dtrain.get_label()
train_accuracy = accuracy_score(y_train,train_predictions)
print("Train Accuracy:%.2f%%"%(train_accuracy*100.0))
#预测(测试集上预测)
preds = bst.predict(dtest)
predictions = [round(value) for value in preds]
y_test = dtest.get_label()
test_accuracy = accuracy_score(y_test,predictions)
print("Test Accuracy:%.2f%%"%(test_accuracy*100.0))
# 模型可视化
'''
可视化模型中的单课树:调用XGBoost的API plot_tree()/to_graphviz()
'''
xgb.plot_tree(bst,num_trees=0,rankdir='LR')
xgb.plot_importance(bst)
plt.show()
'''
* 第一个参数为训练好的模型
* 第二个参数为要打印的树的索引(从0开始)
* 第三个参数是打印的格式
'''
__author__ =