1.1 XGBoost的介绍
XGBoost是2016年由华盛顿大学陈天奇老师带领开发的一个可扩展机器学习系统。严格意义上讲XGBoost并不是一种模型,而是一个可供用户轻松解决分类、回归或排序问题的软件包。它内部实现了梯度提升树(GBDT)模型,并对模型中的算法进行了诸多优化,在取得高精度的同时又保持了极快的速度,在一段时间内成为了国内外数据挖掘、机器学习领域中的大规模杀伤性武器。
更重要的是,XGBoost在系统优化和机器学习原理方面都进行了深入的考虑。毫不夸张的讲,XGBoost提供的可扩展性,可移植性与准确性推动了机器学习计算限制的上限,该系统在单台机器上运行速度比当时流行解决方案快十倍以上,甚至在分布式系统中可以处理十亿级的数据。
XGBoost的主要优点:
- 简单易用。相对其他机器学习库,用户可以轻松使用XGBoost并获得相当不错的效果。
- 高效可扩展。在处理大规模数据集时速度快效果好,对内存等硬件资源要求不高。
- 鲁棒性强。相对于深度学习模型不需要精细调参便能取得接近的效果。
- XGBoost内部实现提升树模型,可以自动处理缺失值。
XGBoost的主要缺点:
- 相对于深度学习模型无法对时空位置建模,不能很好地捕获图像、语音、文本等高维数据。
- 在拥有海量训练数据,并能找到合适的深度学习模型时,深度学习的精度可以遥遥领先XGBoost。
1.2 XGboost的应用
XGBoost在机器学习与数据挖掘领域有着极为广泛的应用。据统计在2015年Kaggle平台上29个获奖方案中,17只队伍使用了XGBoost;在2015年KDD-Cup中,前十名的队伍均使用了XGBoost,且集成其他模型比不上调节XGBoost的参数所带来的提升。这些实实在在的例子都表明,XGBoost在各种问题上都可以取得非常好的效果。
同时,XGBoost还被成功应用在工业界与学术界的各种问题中。例如商店销售额预测、高能物理事件分类、web文本分类;用户行为预测、运动检测、广告点击率预测、恶意软件分类、灾害风险预测、在线课程退学率预测。虽然领域相关的数据分析和特性工程在这些解决方案中也发挥了重要作用,但学习者与实践者对XGBoost的一致选择表明了这一软件包的影响力与重要性。
2. 实验室手册
2.1 学习目标
- 了解 XGBoost 的参数与相关知识
- 掌握 XGBoost 的Python调用并将其运用到天气数据集预测
2.2 代码流程
Part1 基于天气数据集的XGBoost分类实践
- Step1: 库函数导入
- Step2: 数据读取/载入
- Step3: 数据信息简单查看
- Step4: 可视化描述
- Step5: 对离散变量进行编码
- Step6: 利用 XGBoost 进行训练与预测
- Step7: 利用 XGBoost 进行特征选择
- Step8: 通过调整参数获得更好的效果
2.3 算法实战
2.3.1 基于天气数据集的XGBoost分类实战
在实践的最开始,我们首先需要导入一些基础的函数库包括:numpy (Python进行科学计算的基础软件包),pandas(pandas是一种快速,强大,灵活且易于使用的开源数据分析和处理工具),matplotlib和seaborn绘图。
1
#导入需要用到的数据集
2
!wget https://tianchi-media.oss-cn-beijing.aliyuncs.com/DSW/7XGBoost/train.csv
--2020-08-22 17:18:54-- https://tianchi-media.oss-cn-beijing.aliyuncs.com/DSW/7XGBoost/train.csv Resolving tianchi-media.oss-cn-beijing.aliyuncs.com (tianchi-media.oss-cn-beijing.aliyuncs.com)... 47.95.85.21 Connecting to tianchi-media.oss-cn-beijing.aliyuncs.com (tianchi-media.oss-cn-beijing.aliyuncs.com)|47.95.85.21|:443... connected. HTTP request sent, awaiting response... 200 OK Length: 11476379 (11M) [text/csv] Saving to: ‘train.csv’ 100%[======================================>] 11,476,379 23.2MB/s in 0.5s 2020-08-22 17:18:55 (23.2 MB/s) - ‘train.csv’ saved [11476379/11476379]
Step1:函数库导入
1
## 基础函数库
2
import numpy as np
3
import pandas as pd
4
5
## 绘图函数库
6
import matplotlib.pyplot as plt
7
import seaborn as sns
本次我们选择天气数据集进行方法的尝试训练,现在有一些由气象站提供的每日降雨数据,我们需要根据历史降雨数据来预测明天会下雨的概率。样例涉及到的测试集数据test.csv与train.csv的格式完全相同,但其RainTomorrow未给出,为预测变量。
数据的各个特征描述如下:
特征名称 | 意义 | 取值范围 |
---|---|---|
Date | 日期 | 字符串 |
Location | 气象站的地址 | 字符串 |
MinTemp | 最低温度 | 实数 |
MaxTemp | 最高温度 | 实数 |
Rainfall | 降雨量 | 实数 |
Evaporation | 蒸发量 | 实数 |
Sunshine | 光照时间 | 实数 |
WindGustDir | 最强的风的方向 | 字符串 |
WindGustSpeed | 最强的风的速度 | 实数 |
WindDir9am | 早上9点的风向 | 字符串 |
WindDir3pm | 下午3点的风向 | 字符串 |
WindSpeed9am | 早上9点的风速 | 实数 |
WindSpeed3pm | 下午3点的风速 | 实数 |
Humidity9am | 早上9点的湿度 | 实数 |
Humidity3pm | 下午3点的湿度 | 实数 |
Pressure9am | 早上9点的大气压 | 实数 |
Pressure3pm | 早上3点的大气压 | 实数 |
Cloud9am | 早上9点的云指数 | 实数 |
Cloud3pm | 早上3点的云指数 | 实数 |
Temp9am | 早上9点的温度 | 实数 |
Temp3pm | 早上3点的温度 | 实数 |
RainToday | 今天是否下雨 | No,Yes |
RainTomorrow | 明天是否下雨 | No,Yes |
Step2:数据读取/载入
1
## 我们利用Pandas自带的read_csv函数读取并转化为DataFrame格式
2
3
data = pd.read_csv('train.csv')
Step3:数据信息简单查看
1
## 利用.info()查看数据的整体信息
2
data.info()
<class 'pandas.core.frame.DataFrame'> RangeIndex: 106644 entries, 0 to 106643 Data columns (total 23 columns): Date 106644 non-null object Location 106644 non-null object MinTemp 106183 non-null float64 MaxTemp 106413 non-null float64 Rainfall 105610 non-null float64 Evaporation 60974 non-null float64 Sunshine 55718 non-null float64 WindGustDir 99660 non-null object WindGustSpeed 99702 non-null float64 WindDir9am 99166 non-null object WindDir3pm 103788 non-null object WindSpeed9am 105643 non-null float64 WindSpeed3pm 104653 non-null float64 Humidity9am 105327 non-null float64 Humidity3pm 103932 non-null float64 Pressure9am 96107 non-null float64 Pressure3pm 96123 non-null float64 Cloud9am 66303 non-null float64 Cloud3pm 63691 non-null float64 Temp9am 105983 non-null float64 Temp3pm 104599 non-null float64 RainToday 105610 non-null object RainTomorrow 106644 non-null object dtypes: float64(16), object(7) memory usage: 18.7+ MB
1
## 进行简单的数据查看,我们可以利用 .head() 头部.tail()尾部
2
data.head()
[8]:
, , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , ,
Date | Location | MinTemp | MaxTemp | Rainfall | Evaporation | Sunshine | WindGustDir | WindGustSpeed | WindDir9am | ... | Humidity9am | Humidity3pm | Pressure9am | Pressure3pm | Cloud9am | Cloud3pm | Temp9am | Temp3pm | RainToday | RainTomorrow | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 2012/1/19 | MountGinini | 12.1 | 23.1 | 0.0 | NaN | NaN | W | 30.0 | N | ... | 60.0 | 54.0 | NaN | NaN | NaN | NaN | 17.0 | 22.0 | No | No |
1 | 2015/4/13 | Nhil | 10.2 | 24.7 | 0.0 | NaN | NaN | E | 39.0 | E | ... | 63.0 | 33.0 | 1021.9 | 1017.9 | NaN | NaN | 12.5 | 23.7 | No | Yes |
2 | 2010/8/5 | Nuriootpa | -0.4 | 11.0 | 3.6 | 0.4 | 1.6 | W | 28.0 | N | ... | 97.0 | 78.0 | 1025.9 | 1025.3 | 7.0 | 8.0 | 3.9 | 9.0 | Yes | No |
3 | 2013/3/18 | Adelaide | 13.2 | 22.6 | 0.0 | 15.4 | 11.0 | SE | 44.0 | E | ... | 47.0 | 34.0 | 1025.0 | 1022.2 | NaN | NaN | 15.2 | 21.7 | No | No |
4 | 2011/2/16 | Sale | 14.1 | 28.6 | 0.0 | 6.6 | 6.7 | E | 28.0 | NE | ... | 92.0 | 42.0 | 1018.0 | 1014.1 | 4.0 | 7.0 | 19.1 | 28.2 | No | No |
,
5 rows × 23 columns
,
这里我们发现数据集中存在NaN,一般的我们认为NaN在数据集中代表了缺失值,可能是数据采集或处理时产生的一种错误。这里我们采用-1将缺失值进行填补,还有其他例如“中位数填补、平均数填补”的缺失值处理方法有兴趣的同学也可以尝试。
1
data = data.fillna(-1)
1
data.tail()
[10]:
, , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , ,