Python机器学习之XGBoost从入门到实战(代码实现)

# -*- coding: utf-8 -*-
__author__ = 'gerry'

'''
    XGBoost案例之蘑菇是否有毒
        任务:根据蘑菇的22个特征判断蘑菇是否有毒
        数据介绍:
            总样本数:8124
                -可食用:4208,51.8%
                -有毒:3916,48.2%

                -训练样本:6513
                -测试样本:1611
'''
#导入需要的工具包
import xgboost as xgb
from sklearn.metrics import accuracy_score
import matplotlib.pyplot as plt


#将数据从文件中读出,并为XGBoost训练准备好

my_workpath = './data/'
dtrain = xgb.DMatrix(my_workpath+'agaricus.txt.train')
dtest = xgb.DMatrix(my_workpath+'agaricus.txt.test')

'''
    该数据为libsvm格式的文本数据,libsvm的文件格式(稀疏特征)
    -每一行为一个样本:1 3:1 9:1 19:1 21:1 30:1
     * 开头的"1"是样本的标签。3,9位特征索引,1,1为特征的值
     * 在两类分类中,用1表示正样本,0表示负样本,也支持用[0,1]表示概率用来做标签,表示正样本的概率
    XGBoost加载的数据对象存储在对象Dmatrix中,做了存储效率和运行速度的优化
    支持三种数据接口:
        * libsvm.txt格式数据文件
        * 常规矩阵(numpy 2D array)
        * xgboost binary buffer file
'''

#设置训练参数
# specify parameters via map
param = {
    'max_depth':3,
    'eta':1,
    'silent':0,
    'objective':'binary:logistic'
}

'''
    max_depth:树的最大深度,缺省值为6,取值范围:[1,∞]
    eta:为了防止过拟合,更新过程用到的收缩步长,eta通过缩减特征的权重使提升计算过程更加保守。缺省值为0.3,取值范围为[0,1]
    silent:0表示打印出运行时信息,1表示以缄默方式运行,缺省值为0
    objective:定义学习任务以及相应的学习目标,'binary:logistic'表示二分类的逻辑回归问题,输出为概率

'''

# 模型训练
# 设置boosting迭代计算参数
num_round = 2
bst = xgb.train(param,dtrain,num_round)

'''
    与scikit-learn结合
    -XGBoost提供一个wrapper类,允许模型可以和scikit-learn框架中的其他分类器或者回归器一样对待
    XGBoost中分类器为XGBClassifier-模型在构造时传递



'''

#bst = xgb.XGBClassifier(max_depth=2,learning_rate=1,n_estimators=num_round,silent=True,objective='binary:logistic')

#预测(训练数据上评估 )
# 模型训练好后,可以用训练好的模型对进行预测
# XGBoost预测的输出时概率,输出值是样本为第一类的概率-->将其概率值转换为0或1

train_preds = bst.predict(dtrain)
train_predictions = [round(value) for value in train_preds]
y_train = dtrain.get_label()
train_accuracy = accuracy_score(y_train,train_predictions)

print("Train Accuracy:%.2f%%"%(train_accuracy*100.0))


#预测(测试集上预测)

preds = bst.predict(dtest)
predictions = [round(value) for value in preds]
y_test = dtest.get_label()
test_accuracy = accuracy_score(y_test,predictions)

print("Test Accuracy:%.2f%%"%(test_accuracy*100.0))


# 模型可视化
'''
    可视化模型中的单课树:调用XGBoost的API plot_tree()/to_graphviz()

'''
xgb.plot_tree(bst,num_trees=0,rankdir='LR')
xgb.plot_importance(bst)
plt.show()
'''
    * 第一个参数为训练好的模型
    * 第二个参数为要打印的树的索引(从0开始)
    * 第三个参数是打印的格式
'''
# -*- coding: utf-8 -*-
__author__ = 
  • 0
    点赞
  • 22
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值