【Spark ML】Chapter 3: Supervised Learning


Table of Contents

Classification

Binary Classification

Multiclass Classification

Multilabel Classification

Spark MLlib Classification Algorithms

Logistic Regression

Support Vector Machines

Naive Bayes

Multilayer Perceptron

Decision Trees

Random Forest

Gradient-Boosted Trees

Random Forest vs. Gradient-Boosted Trees

Third-Party Classification and Regression Algorithms

Multiclass Classification with Logistic Regression

Churn Prediction with Random Forest

Parameters

Example

Feature Importance

eXtreme Gradient Boosting with XGBoost4J-Spark

Parameters

Example

LightGBM: Fast Gradient Boosting from Microsoft

Parameters

Example

Sentiment Analysis with Naive Bayes

Example

Regression

Simple Linear Regression

Example

Multiple Regression with XGBoost4J-Spark

Example


Supervised learning is a machine learning task that uses a training dataset to make predictions. Supervised learning can be categorized into classification or regression. Regression is for predicting continuous values such as price, temperature, or distance, while classification is for predicting categories such as yes or no, spam or not spam, or malignant or benign.

Classification

Classification is perhaps the most common supervised machine learning task. You have most likely already encountered applications that use classification without even realizing it. Popular use cases include medical diagnosis, targeted marketing, spam detection, credit risk prediction, and sentiment analysis, to name a few. There are three types of classification tasks.

Binary Classification

If there are only two categories, the task is binary or binomial classification. For example, when using a binary classification algorithm for spam detection, the output variable can have two categories: spam or not spam. For detecting cancer, the categories can be malignant or benign. For targeted marketing, predicting the likelihood of someone buying an item such as milk, the categories can simply be yes or no.

Multiclass Classification

Multiclass or multinomial classification tasks have three or more categories. For example, to predict weather conditions you might have five categories: rainy, cloudy, sunny, snowy, and windy. To extend our targeted marketing example, multiclass classification can be used to predict whether a customer is more likely to buy whole milk, reduced-fat milk, low-fat milk, or skim milk.

Multilabel Classification

In multilabel classification, multiple categories can be assigned to each observation. In contrast, only one category can be assigned to an observation in multiclass classification. Using our targeted marketing example, multilabel classification can be used to predict not only whether a customer is more likely to buy milk, but also other items such as cookies, butter, hot dogs, or bread.

Spark MLlib Classification Algorithms

Spark MLlib includes several algorithms for classification. I will cover the most popular ones and provide easy-to-follow code examples that build on what we covered in Chapter 2. Later in the chapter, I will discuss more advanced, next-generation algorithms such as XGBoost and LightGBM.

Logistic Regression

Logistic regression is a linear classifier that predicts probabilities. It uses a logistic (sigmoid) function to convert its output into a probability value that can be mapped to two (binary) classes. Multiclass classification is supported through multinomial logistic (softmax) regression. We will use logistic regression in an example later in this chapter.
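As a quick illustration of the sigmoid mapping just described, the following minimal Scala snippet (not one of the book's listings; the raw score is a made-up value) converts a raw linear score into a probability and thresholds it at 0.5:

// Minimal sketch: the logistic (sigmoid) function squashes any real-valued
// score z into the (0, 1) range, which can be read as P(class = 1).
def sigmoid(z: Double): Double = 1.0 / (1.0 + math.exp(-z))

val rawScore = 2.3                        // hypothetical w·x + b
val p = sigmoid(rawScore)                 // ≈ 0.909
val predictedClass = if (p >= 0.5) 1 else 0
println(s"probability=$p, class=$predictedClass")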

Support Vector Machines

Support vector machine is a popular algorithm that works by finding the optimal hyperplane that maximizes the margin between two classes, dividing the data points into separate classes by as wide a gap as possible. The data points closest to the classification boundary are known as support vectors (Figure 3-1).

Figure 3-1. Finding the optimal hyperplane that maximizes the margin between two classes
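Spark MLlib implements a linear SVM as LinearSVC. The snippet below is a minimal sketch of how it could be fit on a DataFrame with label and features columns; trainingData is only a placeholder name and the parameter values are illustrative:

// Minimal sketch: linear SVM in Spark MLlib.
import org.apache.spark.ml.classification.LinearSVC

val svm = new LinearSVC()
  .setMaxIter(100)       // maximum optimization iterations
  .setRegParam(0.01)     // regularization strength
  .setLabelCol("label")
  .setFeaturesCol("features")

val svmModel = svm.fit(trainingData)   // trainingData: placeholder DataFrame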

Naive Bayes

Naive Bayes is a simple multiclass linear classification algorithm based on Bayes' theorem. Naive Bayes got its name because it naively assumes that the features in a dataset are independent, ignoring any possible correlations between features. We use naive Bayes in the sentiment analysis example later in this chapter.
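For reference, here is a minimal sketch of Spark MLlib's NaiveBayes estimator; the smoothing value and column names are illustrative defaults rather than values from the book's example:

import org.apache.spark.ml.classification.NaiveBayes

val nb = new NaiveBayes()
  .setModelType("multinomial")   // "multinomial" or "bernoulli"
  .setSmoothing(1.0)             // additive (Laplace) smoothing
  .setLabelCol("label")
  .setFeaturesCol("features")

val nbModel = nb.fit(trainingData)   // trainingData: placeholder DataFrame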

Multilayer Perceptron

Multilayer perceptron is a feedforward artificial neural network that consists of several fully connected layers of nodes. Nodes in the input layer correspond to the input dataset. Nodes in the intermediate layers use a logistic (sigmoid) function, while nodes in the final output layer use a softmax function to support multiclass classification. The number of nodes in the output layer must match the number of classes. I discuss multilayer perceptrons in Chapter 7.
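As a minimal sketch, the layer sizes below match the Iris example later in this chapter (4 input features and 3 classes); the hidden-layer size is an arbitrary illustration:

import org.apache.spark.ml.classification.MultilayerPerceptronClassifier

// 4 inputs -> one hidden layer of 5 nodes -> 3 output classes
val mlp = new MultilayerPerceptronClassifier()
  .setLayers(Array(4, 5, 3))
  .setMaxIter(100)
  .setLabelCol("label")
  .setFeaturesCol("features")

val mlpModel = mlp.fit(trainingData)   // trainingData: placeholder DataFrame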

Decision Trees

A decision tree predicts the value of an output variable by learning decision rules inferred from the input variables.

Visually, a decision tree looks like an inverted tree with the root node at the top. Every internal node represents a test on an attribute. Leaf nodes represent a class label, while an individual branch represents the outcome of a test. Figure 3-2 shows a decision tree for predicting credit risk.

Figure 3-2. A decision tree for predicting credit risk

Decision trees perform recursive binary splitting of the feature space. To maximize information gain, the split that produces the largest reduction in impurity is chosen from a set of possible splits. Information gain is calculated by subtracting the weighted sum of the child node impurities from the impurity of the parent node. The lower the impurity of the child nodes, the bigger the information gain. Splitting continues until the maximum tree depth (set by the maxDepth parameter) is reached, no information gain greater than minInfoGain can be achieved, or a split would produce children with fewer training instances than minInstancesPerNode.

There are two impurity measures for classification (Gini impurity and entropy) and one impurity measure for regression (variance). For classification, the default impurity measure in Spark MLlib is Gini impurity. The Gini index is a metric that quantifies the purity of a node. A node contains data from a single class if the Gini index equals zero (the node is pure). A Gini index greater than zero indicates that the node contains data belonging to different classes.
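To make the Gini index concrete, here is a small worked example (not from the book): for a node whose class proportions are p_i, Gini = 1 - Σ p_i², so a hypothetical node with 8 "good" and 2 "bad" credit applicants is computed as follows:

// Gini impurity = 1 - sum(p_i^2) over the class proportions in a node.
def gini(counts: Seq[Double]): Double = {
  val total = counts.sum
  1.0 - counts.map(c => math.pow(c / total, 2)).sum
}

println(gini(Seq(8.0, 2.0)))   // 1 - (0.8^2 + 0.2^2) = 0.32
println(gini(Seq(10.0, 0.0)))  // pure node -> 0.0
println(gini(Seq(5.0, 5.0)))   // maximally mixed (two classes) -> 0.5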

Decision trees are easy to interpret. In contrast to linear models such as logistic regression, decision trees do not require feature scaling. They are able to handle missing features and work with both continuous and categorical features. One-hot encoding of categorical features is not required and is in fact discouraged when using decision trees and tree-based ensembles. One-hot encoding creates unbalanced trees and requires trees to grow extremely deep to achieve good predictive performance. This is especially true for high-cardinality categorical features.

On the downside, decision trees are sensitive to noise in the data and have a tendency to overfit. Because of this limitation, decision trees by themselves are rarely used in real-world production environments. Nowadays, decision trees serve as the base models for more powerful ensemble algorithms such as Random Forest and gradient-boosted trees.

Random Forest

Random Forest is an ensemble algorithm that uses a collection of decision trees for classification and regression. It uses a method called bagging (bootstrap aggregation) to reduce variance while keeping bias low. Bagging trains individual trees on subsets of the training data. In addition to bagging, Random Forest uses another method called feature bagging. In contrast to bagging (which uses subsets of observations), feature bagging uses subsets of features (columns). Feature bagging aims to reduce the correlation between the decision trees. Without feature bagging, the individual trees would be very similar, especially in situations where only a few features dominate.

For classification, the mode, or majority vote of the outputs of the individual trees, becomes the final prediction of the model. For regression, the average of the outputs of the individual trees becomes the final output (Figure 3-3). Spark trains several trees in parallel since each tree in a Random Forest is trained independently. I discuss Random Forest in more detail later in the chapter.

Figure 3-3. Random Forest for classification


Gradient-Boosted Trees

Gradient-boosted trees (GBTs) are another tree-based ensemble algorithm similar to Random Forest. GBTs use a technique known as boosting to create a strong learner from weak learners (shallow trees). GBTs train an ensemble of decision trees sequentially, with each subsequent tree reducing the error of the previous tree. This is done by using the residuals of the previous model to fit the next model. This residual-correction process is repeated for a set number of iterations, determined by cross-validation, until the residuals have been fully minimized.

Figure 3-4. An ensemble of decision trees in GBT

Figure 3-4 shows how an ensemble of decision trees works in GBT, using our credit risk example in which individuals are classified into different leaves based on their creditworthiness. Each leaf in a decision tree is assigned a score. The scores from multiple trees are summed to produce the final prediction score. For instance, Figure 3-4 shows that the first decision tree assigns the woman a score of 3, while the second tree gives her a score of 2. Adding the two scores gives the woman a final score of 5. Note that the decision trees complement each other; this is one of the main principles of GBT. Associating a score with each leaf gives GBT an integrated approach to optimization.

Random Forest vs. Gradient-Boosted Trees

Because gradient-boosted trees are trained sequentially, they are generally considered slower and less scalable than Random Forest, which can train multiple trees in parallel. However, GBTs usually use shallower trees than Random Forest, which means GBTs can be faster to train.

Increasing the number of trees in a GBT increases the chance of overfitting (GBTs reduce bias by using more trees), whereas increasing the number of trees in a Random Forest decreases the chance of overfitting (Random Forests reduce variance by using more trees). In general, adding more trees improves the performance of a Random Forest, while the performance of a GBT starts to degrade once the number of trees becomes too large. As a result, GBTs can be harder to tune than Random Forests.

When its parameters are tuned correctly, a gradient-boosted tree model is generally considered more powerful than a Random Forest. GBTs add new decision trees that complement those already built, and can achieve better predictive accuracy with fewer trees compared to Random Forest.

Most of the newer algorithms for classification and regression developed in recent years, such as XGBoost and LightGBM, are improved variants of GBT. They don't have the limitations of traditional GBTs.

Third-Party Classification and Regression Algorithms

Countless open source contributors have invested time and effort in developing third-party machine learning algorithms for Spark. Although they are not part of the core Spark MLlib library, companies such as Databricks (XGBoost) and Microsoft (LightGBM) have thrown their support behind these projects, and they are used extensively around the world. XGBoost and LightGBM are currently considered the next-generation machine learning algorithms for classification and regression. They are the go-to algorithms in situations where accuracy and speed are critical. I cover both later in this chapter. For now, let's get our hands dirty and dig into some examples.

Multiclass Classification with Logistic Regression

Logistic regression is a linear classifier that predicts probabilities. It is popular for its ease of use and fast training speed and is frequently used for both binary and multiclass classification. A linear classifier such as logistic regression is suitable when the data has a clear decision boundary, as shown in the first chart in Figure 3-5. In cases where the classes are not linearly separable (as shown in the second chart), a nonlinear classifier such as a tree-based ensemble should be considered.

Figure 3-5. Linear vs. nonlinear classification problems

We will work on a multiclass classification problem for our first example, using the popular Iris dataset (see Listing 3-1). The dataset contains three classes of 50 instances each, where each class refers to a type of iris plant (Iris setosa, Iris versicolor, and Iris virginica). As you can see in Figure 3-6, Iris setosa is linearly separable from Iris versicolor and Iris virginica, but Iris versicolor and Iris virginica are not linearly separable from each other. Logistic regression should still do a decent job of classifying the dataset.

Figure 3-6. PCA projection of the Iris dataset

Our goal is to predict the type of iris plant given a set of features. The dataset contains four numeric features: sepal_length, sepal_width, petal_length, and petal_width (all in centimeters).

Create a schema for our data.

import org.apache.spark.sql.types._
var irisSchema = StructType(Array (
    StructField("sepal_length",   DoubleType, true),
    StructField("sepal_width",   DoubleType, true),
    StructField("petal_length",   DoubleType, true),
    StructField("petal_width",   DoubleType, true),
    StructField("class",  StringType, true)
    ))

Read the CSV file using the schema we just defined.

val dataDF = spark.read.format("csv")
             .option("header","false")
             .schema(irisSchema)
             .load("/files/iris.data")

Inspect the schema.

dataDF.printSchema
root
 |-- sepal_length: double (nullable = true)
 |-- sepal_width: double (nullable = true)
 |-- petal_length: double (nullable = true)
 |-- petal_width: double (nullable = true)
 |-- class: string (nullable = true)


Inspect the data to make sure it's in the correct format.

dataDF.show
+------------+-----------+------------+-----------+-----------+
|sepal_length|sepal_width|petal_length|petal_width|      class|
+------------+-----------+------------+-----------+-----------+
|         5.1|        3.5|         1.4|        0.2|Iris-setosa|
|         4.9|        3.0|         1.4|        0.2|Iris-setosa|
|         4.7|        3.2|         1.3|        0.2|Iris-setosa|
|         4.6|        3.1|         1.5|        0.2|Iris-setosa|
|         5.0|        3.6|         1.4|        0.2|Iris-setosa|
|         5.4|        3.9|         1.7|        0.4|Iris-setosa|
|         4.6|        3.4|         1.4|        0.3|Iris-setosa|
|         5.0|        3.4|         1.5|        0.2|Iris-setosa|
|         4.4|        2.9|         1.4|        0.2|Iris-setosa|
|         4.9|        3.1|         1.5|        0.1|Iris-setosa|
|         5.4|        3.7|         1.5|        0.2|Iris-setosa|
|         4.8|        3.4|         1.6|        0.2|Iris-setosa|
|         4.8|        3.0|         1.4|        0.1|Iris-setosa|
|         4.3|        3.0|         1.1|        0.1|Iris-setosa|
|         5.8|        4.0|         1.2|        0.2|Iris-setosa|
|         5.7|        4.4|         1.5|        0.4|Iris-setosa|
|         5.4|        3.9|         1.3|        0.4|Iris-setosa|
|         5.1|        3.5|         1.4|        0.3|Iris-setosa|
|         5.7|        3.8|         1.7|        0.3|Iris-setosa|
|         5.1|        3.8|         1.5|        0.3|Iris-setosa|
+------------+-----------+------------+-----------+-----------+
only showing top 20 rows

Calculate summary statistics for our data. This can be helpful in understanding the distribution of the data.

dataDF.describe().show(5,15)
+-------+---------------+---------------+---------------+---------------+
|summary|   sepal_length|    sepal_width|   petal_length|    petal_width|
+-------+---------------+---------------+---------------+---------------+
|  count|            150|            150|            150|            150|
|   mean|5.8433333333...|3.0540000000...|3.7586666666...|1.1986666666...|
| stddev|0.8280661279...|0.4335943113...|1.7644204199...|0.7631607417...|
|    min|            4.3|            2.0|            1.0|            0.1|
|    max|            7.9|            4.4|            6.9|            2.5|
+-------+---------------+---------------+---------------+---------------+
+--------------+
|         class|
+--------------+
|           150|
|          null|
|          null|
|   Iris-setosa|
|Iris-virginica|
+--------------+

The input column class is currently a string. We'll use StringIndexer to encode it as a double. The new values will be stored in a new output column called label.

import org.apache.spark.ml.feature.StringIndexer
val labelIndexer = new StringIndexer()
                  .setInputCol("class")
                  .setOutputCol("label")
val dataDF2 = labelIndexer
             .fit(dataDF)
             .transform(dataDF)

Inspect the schema of the new DataFrame.

dataDF2.printSchema
root
 |-- sepal_length: double (nullable = true)
 |-- sepal_width: double (nullable = true)
 |-- petal_length: double (nullable = true)
 |-- petal_width: double (nullable = true)
 |-- class: string (nullable = true)
 |-- label: double (nullable = false)

Inspect the new column added to the DataFrame.

dataDF2.show
+------------+-----------+------------+-----------+-----------+-----+
|sepal_length|sepal_width|petal_length|petal_width|      class|label|
+------------+-----------+------------+-----------+-----------+-----+
|         5.1|        3.5|         1.4|        0.2|Iris-setosa|  0.0|
|         4.9|        3.0|         1.4|        0.2|Iris-setosa|  0.0|
|         4.7|        3.2|         1.3|        0.2|Iris-setosa|  0.0|
|         4.6|        3.1|         1.5|        0.2|Iris-setosa|  0.0|
|         5.0|        3.6|         1.4|        0.2|Iris-setosa|  0.0|
|         5.4|        3.9|         1.7|        0.4|Iris-setosa|  0.0|
|         4.6|        3.4|         1.4|        0.3|Iris-setosa|  0.0|
|         5.0|        3.4|         1.5|        0.2|Iris-setosa|  0.0|
|         4.4|        2.9|         1.4|        0.2|Iris-setosa|  0.0|
|         4.9|        3.1|         1.5|        0.1|Iris-setosa|  0.0|
|         5.4|        3.7|         1.5|        0.2|Iris-setosa|  0.0|
|         4.8|        3.4|         1.6|        0.2|Iris-setosa|  0.0|
|         4.8|        3.0|         1.4|        0.1|Iris-setosa|  0.0|
|         4.3|        3.0|         1.1|        0.1|Iris-setosa|  0.0|
|         5.8|        4.0|         1.2|        0.2|Iris-setosa|  0.0|
|         5.7|        4.4|         1.5|        0.4|Iris-setosa|  0.0|
|         5.4|        3.9|         1.3|        0.4|Iris-setosa|  0.0|
|         5.1|        3.5|         1.4|        0.3|Iris-setosa|  0.0|
|         5.7|        3.8|         1.7|        0.3|Iris-setosa|  0.0|
|         5.1|        3.8|         1.5|        0.3|Iris-setosa|  0.0|
+------------+-----------+------------+-----------+-----------+-----+
only showing top 20 rows

Combine the feature columns into a single vector column using the VectorAssembler transformer.

import org.apache.spark.ml.feature.VectorAssembler
val features = Array("sepal_length","sepal_width","petal_length","petal_width")
val assembler = new VectorAssembler()
                .setInputCols(features)
                .setOutputCol("features")
val dataDF3 = assembler.transform(dataDF2)

Inspect the new column added to the DataFrame.

dataDF3.printSchema
root
 |-- sepal_length: double (nullable = true)
 |-- sepal_width: double (nullable = true)
 |-- petal_length: double (nullable = true)
 |-- petal_width: double (nullable = true)
 |-- class: string (nullable = true)
 |-- label: double (nullable = false)
 |-- features: vector (nullable = true)

Inspect the new column added to the DataFrame.

dataDF3.show
+------------+-----------+------------+-----------+-----------+-----+
|sepal_length|sepal_width|petal_length|petal_width|      class|label|
+------------+-----------+------------+-----------+-----------+-----+
|         5.1|        3.5|         1.4|        0.2|Iris-setosa|  0.0|
|         4.9|        3.0|         1.4|        0.2|Iris-setosa|  0.0|
|         4.7|        3.2|         1.3|        0.2|Iris-setosa|  0.0|
|         4.6|        3.1|         1.5|        0.2|Iris-setosa|  0.0|
|         5.0|        3.6|         1.4|        0.2|Iris-setosa|  0.0|
|         5.4|        3.9|         1.7|        0.4|Iris-setosa|  0.0|
|         4.6|        3.4|         1.4|        0.3|Iris-setosa|  0.0|
|         5.0|        3.4|         1.5|        0.2|Iris-setosa|  0.0|
|         4.4|        2.9|         1.4|        0.2|Iris-setosa|  0.0|
|         4.9|        3.1|         1.5|        0.1|Iris-setosa|  0.0|
|         5.4|        3.7|         1.5|        0.2|Iris-setosa|  0.0|
|         4.8|        3.4|         1.6|        0.2|Iris-setosa|  0.0|
|         4.8|        3.0|         1.4|        0.1|Iris-setosa|  0.0|
|         4.3|        3.0|         1.1|        0.1|Iris-setosa|  0.0|
|         5.8|        4.0|         1.2|        0.2|Iris-setosa|  0.0|
|         5.7|        4.4|         1.5|        0.4|Iris-setosa|  0.0|
|         5.4|        3.9|         1.3|        0.4|Iris-setosa|  0.0|
|         5.1|        3.5|         1.4|        0.3|Iris-setosa|  0.0|
|         5.7|        3.8|         1.7|        0.3|Iris-setosa|  0.0|
|         5.1|        3.8|         1.5|        0.3|Iris-setosa|  0.0|
+------------+-----------+------------+-----------+-----------+-----+
+-----------------+
|         features|
+-----------------+
|[5.1,3.5,1.4,0.2]|
|[4.9,3.0,1.4,0.2]|
|[4.7,3.2,1.3,0.2]|
|[4.6,3.1,1.5,0.2]|
|[5.0,3.6,1.4,0.2]|
|[5.4,3.9,1.7,0.4]|
|[4.6,3.4,1.4,0.3]|
|[5.0,3.4,1.5,0.2]|
|[4.4,2.9,1.4,0.2]|
|[4.9,3.1,1.5,0.1]|
|[5.4,3.7,1.5,0.2]|
|[4.8,3.4,1.6,0.2]|
|[4.8,3.0,1.4,0.1]|
|[4.3,3.0,1.1,0.1]|
|[5.8,4.0,1.2,0.2]|
|[5.7,4.4,1.5,0.4]|
|[5.4,3.9,1.3,0.4]|
|[5.1,3.5,1.4,0.3]|
|[5.7,3.8,1.7,0.3]|
|[5.1,3.8,1.5,0.3]|
+-----------------+
only showing top 20 rows

Let's measure the statistical dependence between the features and the class using Pearson correlation.

dataDF3.stat.corr("petal_length","label")
res48: Double = 0.9490425448523336
dataDF3.stat.corr("petal_width","label")
res49: Double = 0.9564638238016178
dataDF3.stat.corr("sepal_length","label")
res50: Double = 0.7825612318100821
dataDF3.stat.corr("sepal_width","label")
res51: Double = -0.41944620026002677

petal_length and petal_width have an extremely high class correlation, while sepal_length and sepal_width have a low class correlation. As discussed in Chapter 2, correlation evaluates the strength of the linear relationship between two variables. You can use correlation to select relevant features (feature-class correlation) and to identify redundant features (intra-feature correlation).

Split our dataset into training and test datasets.

val seed = 1234
val Array(trainingData, testData) = dataDF3.randomSplit(Array(0.8, 0.2), seed)

We can now fit a model on the training dataset using logistic regression.

import org.apache.spark.ml.classification.LogisticRegression
val lr = new LogisticRegression()

Train the model using our training dataset.

val model = lr.fit(trainingData)

Make predictions on our test dataset.

val predictions = model.transform(testData)

Note the new columns added to the DataFrame: rawPrediction, probability, and prediction.

predictions.printSchema
root
 |-- sepal_length: double (nullable = true)
 |-- sepal_width: double (nullable = true)
 |-- petal_length: double (nullable = true)
 |-- petal_width: double (nullable = true)
 |-- class: string (nullable = true)
 |-- label: double (nullable = false)
 |-- features: vector (nullable = true)
 |-- rawPrediction: vector (nullable = true)
 |-- probability: vector (nullable = true)
 |-- prediction: double (nullable = false)

Inspect the predictions.

predictions.select("sepal_length","sepal_width",
"petal_length","petal_width","label","prediction").show
+------------+-----------+------------+-----------+-----+----------+
|sepal_length|sepal_width|petal_length|petal_width|label|prediction|
+------------+-----------+------------+-----------+-----+----------+
|         4.3|        3.0|         1.1|        0.1|  0.0|       0.0|
|         4.4|        2.9|         1.4|        0.2|  0.0|       0.0|
|         4.4|        3.0|         1.3|        0.2|  0.0|       0.0|
|         4.8|        3.1|         1.6|        0.2|  0.0|       0.0|
|         5.0|        3.3|         1.4|        0.2|  0.0|       0.0|
|         5.0|        3.4|         1.5|        0.2|  0.0|       0.0|
|         5.0|        3.6|         1.4|        0.2|  0.0|       0.0|
|         5.1|        3.4|         1.5|        0.2|  0.0|       0.0|
|         5.2|        2.7|         3.9|        1.4|  1.0|       1.0|
|         5.2|        4.1|         1.5|        0.1|  0.0|       0.0|
|         5.3|        3.7|         1.5|        0.2|  0.0|       0.0|
|         5.6|        2.9|         3.6|        1.3|  1.0|       1.0|
|         5.8|        2.8|         5.1|        2.4|  2.0|       2.0|
|         6.0|        2.2|         4.0|        1.0|  1.0|       1.0|
|         6.0|        2.9|         4.5|        1.5|  1.0|       1.0|
|         6.0|        3.4|         4.5|        1.6|  1.0|       1.0|
|         6.2|        2.8|         4.8|        1.8|  2.0|       2.0|
|         6.2|        2.9|         4.3|        1.3|  1.0|       1.0|
|         6.3|        2.8|         5.1|        1.5|  2.0|       1.0|
|         6.7|        3.1|         5.6|        2.4|  2.0|       2.0|
+------------+-----------+------------+-----------+-----+----------+
only showing top 20 rows

Inspect the rawPrediction and probability columns.

predictions.select("rawPrediction","probability","prediction")
           .show(false)
+------------------------------------------------------------+
|rawPrediction                                               |
+------------------------------------------------------------+
|[-27765.164694901094,17727.78535517628,10037.379339724806]  |
|[-24491.649758932126,13931.526474094646,10560.123284837473] |
|[20141.806983153703,1877.784589255676,-22019.591572409383]  |
|[-46255.06332259462,20994.503038678085,25260.560283916537]  |
|[25095.115980666546,110.99834659454791,-25206.114327261093] |
|[-41011.14350152455,17036.32945903473,23974.814042489823]   |
|[20524.55747106708,1750.139974552606,-22274.697445619684]   |
|[29601.783587714817,-1697.1845083924927,-27904.599079322325]|
|[38919.06696252647,-5453.963471106039,-33465.10349142042]   |
|[-39965.27448934488,17725.41646382807,22239.85802551682]    |
|[-18994.667253235268,12074.709651218403,6919.957602016859]  |
|[-43236.84898013162,18023.80837865029,25213.040601481334]   |
|[-31543.179893646557,16452.928101990834,15090.251791655724] |
|[-21666.087284218,13802.846783092147,7863.24050112584]      |
|[-24107.97243292983,14585.93668397567,9522.035748954155]    |
|[25629.52586174148,-192.40731255107312,-25437.11854919041]  |
|[-14271.522512385294,11041.861803401871,3229.660708983418]  |
|[-16548.06114507441,10139.917257827732,6408.143887246673]   |
|[22598.60355651257,938.4220993796007,-23537.025655892172]   |
|[-40984.78286289556,18297.704445848023,22687.078417047538]  |
+------------------------------------------------------------+
+-------------+----------+
|probability  |prediction|
+-------------+----------+
|[0.0,1.0,0.0]|1.0       |
|[0.0,1.0,0.0]|1.0       |
|[1.0,0.0,0.0]|0.0       |
|[0.0,0.0,1.0]|2.0       |
|[1.0,0.0,0.0]|0.0       |
|[0.0,0.0,1.0]|2.0       |
|[1.0,0.0,0.0]|0.0       |
|[1.0,0.0,0.0]|0.0       |
|[1.0,0.0,0.0]|0.0       |
|[0.0,0.0,1.0]|2.0       |
|[0.0,1.0,0.0]|1.0       |
|[0.0,1.0,0.0]|1.0       |
|[0.0,1.0,0.0]|1.0       |
|[1.0,0.0,0.0]|0.0       |
|[0.0,1.0,0.0]|1.0       |
|[0.0,1.0,0.0]|1.0       |
|[1.0,0.0,0.0]|0.0       |
|[0.0,0.0,1.0]|2.0       |
+-------------+----------+
only showing top 20 rows

Evaluate the model. Several evaluation metrics are available for multiclass classification: f1 (the default), accuracy, weightedPrecision, and weightedRecall. I discuss evaluation metrics in more detail in Chapter 2.

import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
val evaluator = new MulticlassClassificationEvaluator().setMetricName("f1")
val f1 = evaluator.evaluate(predictions)
f1: Double = 0.958119658119658
val wp = evaluator.setMetricName("weightedPrecision").evaluate(predictions)
wp: Double = 0.9635416666666667
val wr = evaluator.setMetricName("weightedRecall").evaluate(predictions)
wr: Double = 0.9583333333333335
val accuracy = evaluator.setMetricName("accuracy").evaluate(predictions)
accuracy: Double = 0.9583333333333334

Listing 3-1. Classification with logistic regression

Logistic regression is a popular classification algorithm that is often used as a first baseline algorithm due to its speed and simplicity. For production use, more advanced tree-based ensembles are generally preferred due to their superior accuracy and ability to capture complex nonlinear relationships in datasets.

Churn Prediction with Random Forest

Random Forest is a powerful ensemble learning algorithm built on multiple decision trees as base models, each trained in parallel on a different bootstrap subset of the data. As mentioned earlier, decision trees have a tendency to overfit. Random Forest addresses overfitting by using a technique known as bagging (bootstrap aggregation), training each decision tree on a randomly selected subset of the data. Bagging reduces the variance of the model and helps avoid overfitting. Random Forest reduces the variance of the model without increasing bias. It also performs feature bagging, randomly selecting features for each decision tree. The goal of feature bagging is to reduce the correlation between individual trees.

For classification, the final class is determined by majority vote: the mode (most frequently occurring) of the classes produced by the individual decision trees becomes the final class. For regression, the average of the outputs of the individual decision trees becomes the final output of the model.

Since Random Forest uses decision trees as its base models, it inherits most of their qualities. It can handle both continuous and categorical features and does not require feature scaling or one-hot encoding. Random Forest also performs well on imbalanced data because its hierarchical nature forces it to deal with both classes. Finally, Random Forest can capture nonlinear relationships between the dependent and independent variables.

Random Forest is one of the most popular tree-based ensemble algorithms for classification and regression because of its interpretability, accuracy, and flexibility. However, training a Random Forest model can be computationally intensive (which makes it an ideal candidate for parallelization in multicore and distributed environments such as Hadoop or Spark). It requires more memory and computational resources than linear models such as logistic regression or naive Bayes. In addition, Random Forests tend to perform poorly on high-dimensional data such as text or genomic data.

Note: The CSIRO Bioinformatics team has developed a highly scalable Random Forest implementation, originally designed for high-dimensional genomic data, called VariantSpark RF. VariantSpark RF can handle millions of features and has shown greater scalability than MLlib's Random Forest implementation in benchmarks. More information about VariantSpark RF can be found on CSIRO's Bioinformatics website. ReForeSt is another highly scalable Random Forest implementation, developed by the SmartLab research laboratory at DIBRIS, University of Genoa, Italy. ReForeSt can handle millions of features and also supports random rotation forests, a newer ensemble method that extends the classic Random Forest algorithm.

Parameters

Random Forest is relatively easy to tune. Setting a few important parameters correctly is usually enough to use Random Forest successfully.

  • maxDepth: Specifies the maximum depth of the trees. Setting a higher value for maxDepth makes the model more expressive, but setting it too high may increase the likelihood of overfitting and make the model more complex.
  • numTrees: Specifies the number of trees to fit. Increasing the number of trees decreases variance and usually improves accuracy. Increasing the number of trees slows down training. Beyond a certain point, adding more trees may not improve accuracy.
  • featureSubsetStrategy: Specifies the fraction of features used for splitting at each node. Setting this parameter can improve training speed.
  • subsamplingRate: Specifies the fraction of the data that will be selected to train each tree. Setting this parameter can improve training speed and can help prevent overfitting. Setting it too low may cause underfitting.

I have provided some general guidelines, but as always, performing a parameter grid search to determine the optimal values for these parameters is highly recommended. See Spark MLlib's online documentation for a full list of Random Forest parameters.
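As a minimal sketch of how these parameters map onto Spark MLlib's API (the specific values here are illustrative, not tuned):

import org.apache.spark.ml.classification.RandomForestClassifier

val rfSketch = new RandomForestClassifier()
  .setNumTrees(100)                   // number of trees in the forest
  .setMaxDepth(10)                    // maximum depth of each tree
  .setFeatureSubsetStrategy("sqrt")   // features considered per split
  .setSubsamplingRate(0.8)            // fraction of data used per tree
  .setLabelCol("label")
  .setFeaturesCol("features")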

Example

Churn prediction is an important classification use case for banks, insurance companies, telecommunication companies, cable operators, and streaming services such as Netflix, Hulu, Spotify, and Apple Music. Companies that can predict which customers are more likely to cancel their subscriptions can implement more effective customer retention strategies. Retaining customers is valuable. According to a study by a leading customer engagement analytics company, customer churn costs US businesses an estimated $136 billion a year. Research conducted by Bain & Company suggests that increasing customer retention by just 5% increases profits by 25% to 95%. Another statistic, from Lee Resources Inc., shows that attracting new customers costs a company five times more than keeping existing ones.

For our example, we will use the popular telco churn dataset from the University of California, Irvine Machine Learning Repository (see Listing 3-2). It is a popular Kaggle dataset and is widely used online.

For most of the examples in this book, I execute the transformers and estimators individually (instead of specifying them in a pipeline) so that you can see the new columns added to the resulting DataFrame. This should help you see what's happening "under the hood" as you work through the examples.

Load the CSV file into a DataFrame.

val dataDF = spark.read.format("csv")
             .option("header", "true")
             .option("inferSchema", "true")
             .load("churn_data.txt")

Inspect the schema.

dataDF.printSchema
root
 |-- state: string (nullable = true)
 |-- account_length: double (nullable = true)
 |-- area_code: double (nullable = true)
 |-- phone_number: string (nullable = true)
 |-- international_plan: string (nullable = true)
 |-- voice_mail_plan: string (nullable = true)
 |-- number_vmail_messages: double (nullable = true)
 |-- total_day_minutes: double (nullable = true)
 |-- total_day_calls: double (nullable = true)
 |-- total_day_charge: double (nullable = true)
 |-- total_eve_minutes: double (nullable = true)
 |-- total_eve_calls: double (nullable = true)
 |-- total_eve_charge: double (nullable = true)
 |-- total_night_minutes: double (nullable = true)
 |-- total_night_calls: double (nullable = true)
 |-- total_night_charge: double (nullable = true)
 |-- total_intl_minutes: double (nullable = true)
 |-- total_intl_calls: double (nullable = true)
 |-- total_intl_charge: double (nullable = true)
 |-- number_customer_service_calls: double (nullable = true)
 |-- churned: string (nullable = true)

Select a few columns.

dataDF
.select("state","phone_number","international_plan","total_day_minutes","churned").show
+-----+------------+------------------+-----------------+-------+
|state|phone_number|international_plan|total_day_minutes|churned|
+-----+------------+------------------+-----------------+-------+
|   KS|    382-4657|                no|            265.1|  False|
|   OH|    371-7191|                no|            161.6|  False|
|   NJ|    358-1921|                no|            243.4|  False|
|   OH|    375-9999|               yes|            299.4|  False|
|   OK|    330-6626|               yes|            166.7|  False|
|   AL|    391-8027|               yes|            223.4|  False|
|   MA|    355-9993|                no|            218.2|  False|
|   MO|    329-9001|               yes|            157.0|  False|
|   LA|    335-4719|                no|            184.5|  False|
|   WV|    330-8173|               yes|            258.6|  False|
|   IN|    329-6603|                no|            129.1|   True|
|   RI|    344-9403|                no|            187.7|  False|
|   IA|    363-1107|                no|            128.8|  False|
|   MT|    394-8006|                no|            156.6|  False|
|   IA|    366-9238|                no|            120.7|  False|
|   NY|    351-7269|                no|            332.9|   True|
|   ID|    350-8884|                no|            196.4|  False|
|   VT|    386-2923|                no|            190.7|  False|
|   VA|    356-2992|                no|            189.7|  False|
|   TX|    373-2782|                no|            224.4|  False|
+-----+------------+------------------+-----------------+-------+
only showing top 20 rows
import org.apache.spark.ml.feature.StringIndexer

Convert the string column "churned" ("True", "False") into a double (1, 0).

val labelIndexer = new StringIndexer()
                   .setInputCol("churned")
                   .setOutputCol("label")

Convert the string column "international_plan" ("yes", "no") into a double (1, 0).

val intPlanIndexer = new StringIndexer()
                     .setInputCol("international_plan")
                     .setOutputCol("int_plan")

Let's select our features. Domain knowledge is crucial in feature selection. I assume total_day_minutes and total_day_calls have some effect on customer churn; a big drop in both metrics could indicate that the customer no longer needs the service and may be about to cancel their phone plan. However, I don't believe phone_number, area_code, and state have any predictive qualities. We discuss feature selection later in this chapter.

val features = Array("number_customer_service_calls","total_day_minutes","total_eve_minutes","account_length","number_vmail_messages","total_day_calls","total_day_charge","total_eve_calls","total_eve_charge","total_night_calls","total_intl_calls","total_intl_charge","int_plan")

Combine the given list of columns into a single vector column containing all the features needed to train the ML model.

import org.apache.spark.ml.feature.VectorAssembler
val assembler = new VectorAssembler()
                .setInputCols(features)
                .setOutputCol("features")


Add the label column to the DataFrame.

val dataDF2 = labelIndexer
              .fit(dataDF)
              .transform(dataDF)
dataDF2.printSchema
root
 |-- state: string (nullable = true)
 |-- account_length: double (nullable = true)
 |-- area_code: double (nullable = true)
 |-- phone_number: string (nullable = true)
 |-- international_plan: string (nullable = true)
 |-- voice_mail_plan: string (nullable = true)
 |-- number_vmail_messages: double (nullable = true)
 |-- total_day_minutes: double (nullable = true)
 |-- total_day_calls: double (nullable = true)
 |-- total_day_charge: double (nullable = true)
 |-- total_eve_minutes: double (nullable = true)
 |-- total_eve_calls: double (nullable = true)
 |-- total_eve_charge: double (nullable = true)
 |-- total_night_minutes: double (nullable = true)
 |-- total_night_calls: double (nullable = true)
 |-- total_night_charge: double (nullable = true)
 |-- total_intl_minutes: double (nullable = true)
 |-- total_intl_calls: double (nullable = true)
 |-- total_intl_charge: double (nullable = true)
 |-- number_customer_service_calls: double (nullable = true)
 |-- churned: string (nullable = true)
 |-- label: double (nullable = false)


"True" is converted to 1 and "False" to 0.

dataDF2.select("churned","label").show
+-------+-----+
|churned|label|
+-------+-----+
|  False|  0.0|
|  False|  0.0|
|  False|  0.0|
|  False|  0.0|
|  False|  0.0|
|  False|  0.0|
|  False|  0.0|
|  False|  0.0|
|  False|  0.0|
|  False|  0.0|
|   True|  1.0|
|  False|  0.0|
|  False|  0.0|
|  False|  0.0|
|  False|  0.0|
|   True|  1.0|
|  False|  0.0|
|  False|  0.0|
|  False|  0.0|
|  False|  0.0|
+-------+-----+
only showing top 20 rows


Add the int_plan column to the DataFrame.

val dataDF3 = intPlanIndexer.fit(dataDF2).transform(dataDF2)
dataDF3.printSchema
root
 |-- state: string (nullable = true)
 |-- account_length: double (nullable = true)
 |-- area_code: double (nullable = true)
 |-- phone_number: string (nullable = true)
 |-- international_plan: string (nullable = true)
 |-- voice_mail_plan: string (nullable = true)
 |-- number_vmail_messages: double (nullable = true)
 |-- total_day_minutes: double (nullable = true)
 |-- total_day_calls: double (nullable = true)
 |-- total_day_charge: double (nullable = true)
 |-- total_eve_minutes: double (nullable = true)
 |-- total_eve_calls: double (nullable = true)
 |-- total_eve_charge: double (nullable = true)
 |-- total_night_minutes: double (nullable = true)
 |-- total_night_calls: double (nullable = true)
 |-- total_night_charge: double (nullable = true)
 |-- total_intl_minutes: double (nullable = true)
 |-- total_intl_calls: double (nullable = true)
 |-- total_intl_charge: double (nullable = true)
 |-- number_customer_service_calls: double (nullable = true)
 |-- churned: string (nullable = true)
 |-- label: double (nullable = false)
 |-- int_plan: double (nullable = false)
dataDF3.select("international_plan","int_plan").show
+------------------+--------+
|international_plan|int_plan|
+------------------+--------+
|                no|     0.0|
|                no|     0.0|
|                no|     0.0|
|               yes|     1.0|
|               yes|     1.0|
|               yes|     1.0|
|                no|     0.0|
|               yes|     1.0|
|                no|     0.0|
|               yes|     1.0|
|                no|     0.0|
|                no|     0.0|
|                no|     0.0|
|                no|     0.0|
|                no|     0.0|
|                no|     0.0|
|                no|     0.0|
|                no|     0.0|
|                no|     0.0|
|                no|     0.0|
+------------------+--------+
only showing top 20 rows


Add the features vector column to the DataFrame.

val dataDF4 = assembler.transform(dataDF3)
dataDF4.printSchema
root
 |-- state: string (nullable = true)
 |-- account_length: double (nullable = true)
 |-- area_code: double (nullable = true)
 |-- phone_number: string (nullable = true)
 |-- international_plan: string (nullable = true)
 |-- voice_mail_plan: string (nullable = true)
 |-- number_vmail_messages: double (nullable = true)
 |-- total_day_minutes: double (nullable = true)
 |-- total_day_calls: double (nullable = true)
 |-- total_day_charge: double (nullable = true)
 |-- total_eve_minutes: double (nullable = true)
 |-- total_eve_calls: double (nullable = true)
 |-- total_eve_charge: double (nullable = true)
 |-- total_night_minutes: double (nullable = true)
 |-- total_night_calls: double (nullable = true)
 |-- total_night_charge: double (nullable = true)
 |-- total_intl_minutes: double (nullable = true)
 |-- total_intl_calls: double (nullable = true)
 |-- total_intl_charge: double (nullable = true)
 |-- number_customer_service_calls: double (nullable = true)
 |-- churned: string (nullable = true)
 |-- label: double (nullable = false)
 |-- int_plan: double (nullable = false)
 |-- features: vector (nullable = true)


The features have been vectorized.

dataDF4.select("features").show(false)
+----------------------------------------------------------------------+
|features                                                              |
+----------------------------------------------------------------------+
|[1.0,265.1,197.4,128.0,25.0,110.0,45.07,99.0,16.78,91.0,3.0,2.7,0.0]  |
|[1.0,161.6,195.5,107.0,26.0,123.0,27.47,103.0,16.62,103.0,3.0,3.7,0.0]|
|[0.0,243.4,121.2,137.0,0.0,114.0,41.38,110.0,10.3,104.0,5.0,3.29,0.0] |
|[2.0,299.4,61.9,84.0,0.0,71.0,50.9,88.0,5.26,89.0,7.0,1.78,1.0]       |
|[3.0,166.7,148.3,75.0,0.0,113.0,28.34,122.0,12.61,121.0,3.0,2.73,1.0] |
|[0.0,223.4,220.6,118.0,0.0,98.0,37.98,101.0,18.75,118.0,6.0,1.7,1.0]  |
|[3.0,218.2,348.5,121.0,24.0,88.0,37.09,108.0,29.62,118.0,7.0,2.03,0.0]|
|[0.0,157.0,103.1,147.0,0.0,79.0,26.69,94.0,8.76,96.0,6.0,1.92,1.0]    |
|[1.0,184.5,351.6,117.0,0.0,97.0,31.37,80.0,29.89,90.0,4.0,2.35,0.0]   |
|[0.0,258.6,222.0,141.0,37.0,84.0,43.96,111.0,18.87,97.0,5.0,3.02,1.0] |
|[4.0,129.1,228.5,65.0,0.0,137.0,21.95,83.0,19.42,111.0,6.0,3.43,0.0]  |
|[0.0,187.7,163.4,74.0,0.0,127.0,31.91,148.0,13.89,94.0,5.0,2.46,0.0]  |
|[1.0,128.8,104.9,168.0,0.0,96.0,21.9,71.0,8.92,128.0,2.0,3.02,0.0]    |
|[3.0,156.6,247.6,95.0,0.0,88.0,26.62,75.0,21.05,115.0,5.0,3.32,0.0]   |
|[4.0,120.7,307.2,62.0,0.0,70.0,20.52,76.0,26.11,99.0,6.0,3.54,0.0]    |
|[4.0,332.9,317.8,161.0,0.0,67.0,56.59,97.0,27.01,128.0,9.0,1.46,0.0]  |
|[1.0,196.4,280.9,85.0,27.0,139.0,33.39,90.0,23.88,75.0,4.0,3.73,0.0]  |
|[3.0,190.7,218.2,93.0,0.0,114.0,32.42,111.0,18.55,121.0,3.0,2.19,0.0] |
|[1.0,189.7,212.8,76.0,33.0,66.0,32.25,65.0,18.09,108.0,5.0,2.7,0.0]   |
|[1.0,224.4,159.5,73.0,0.0,90.0,38.15,88.0,13.56,74.0,2.0,3.51,0.0]    |
+----------------------------------------------------------------------+
only showing top 20 rows


Split the data into training and test datasets.

val seed = 1234
val Array(trainingData, testData) = dataDF4.randomSplit(Array(0.8, 0.2), seed)
trainingData.count
res13: Long = 4009
testData.count
res14: Long = 991


Create a Random Forest classifier.

import org.apache.spark.ml.classification.RandomForestClassifier
val rf = new RandomForestClassifier()
        .setFeatureSubsetStrategy("auto")
        .setSeed(seed)


Create a binary classification evaluator and set the label column to be used for evaluation.

import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
val evaluator = new BinaryClassificationEvaluator().setLabelCol("label")


Create the parameter grid.

import org.apache.spark.ml.tuning.ParamGridBuilder
val paramGrid = new ParamGridBuilder()
                .addGrid(rf.maxBins, Array(10, 20,30))
                .addGrid(rf.maxDepth, Array(5, 10, 15))
                .addGrid(rf.numTrees, Array(3, 5, 100))
                .addGrid(rf.impurity, Array("gini", "entropy"))
                .build()


Create the pipeline.

import org.apache.spark.ml.Pipeline
val pipeline = new Pipeline().setStages(Array(rf))


Create the cross-validator.

import org.apache.spark.ml.tuning.CrossValidator
val cv = new CrossValidator()
         .setEstimator(pipeline)
         .setEvaluator(evaluator)
         .setEstimatorParamMaps(paramGrid)
         .setNumFolds(3)


We can now fit the model using the training dataset, selecting the best set of parameters for the model.

val model = cv.fit(trainingData)


You can now make predictions on our test data.

val predictions = model.transform(testData)


Evaluate the model.

import org.apache.spark.ml.param.ParamMap
val pmap = ParamMap(evaluator.metricName -> "areaUnderROC")
val auc = evaluator.evaluate(predictions, pmap)
auc: Double = 0.9270599683335483


Our Random Forest classifier has a high AUC score. The test data consists of 991 observations, of which 92 customers are predicted to leave the service.

predictions.count
res25: Long = 991
predictions.filter("prediction=1").count
res26: Long = 92
println(s"True Negative: ${predictions.select("*").where("prediction = 0 AND label = 0").count()}  True Positive: ${predictions.select("*").where("prediction = 1 AND label = 1").count()}")
True Negative: 837 True Positive: 81


Our test predicted 81 customers who left and who actually did leave, and it also predicted 837 customers who did not leave and who actually did not leave.

println(s"False Negative: ${predictions.select("*").where("prediction = 0 AND label = 1").count()} False Positive: ${predictions.select("*").where("prediction = 1 AND label = 0").count()}")
False Negative: 62 False Positive: 11


Our test predicted 11 customers who left but who actually did not leave, and it also predicted 62 customers who did not leave but who actually did leave.
You can sort the output by rawPrediction or probability to target the customers with the highest probability of churning. rawPrediction and probability provide a measure of confidence for each prediction. The larger the values, the more confident the model is in its prediction.

predictions.select("phone_number","RawPrediction","prediction")
           .orderBy($"RawPrediction".asc)
           .show(false)
+------------+--------------------------------------+----------+
|phone_number|RawPrediction                         |prediction|
+------------+--------------------------------------+----------+
| 366-1084   |[15.038138063913935,84.96186193608602]|1.0       |
| 334-6519   |[15.072688486480072,84.9273115135199] |1.0       |
| 359-5574   |[15.276260309388752,84.72373969061123]|1.0       |
| 399-7865   |[15.429722388653014,84.57027761134698]|1.0       |
| 335-2967   |[16.465107279664032,83.53489272033593]|1.0       |
| 345-9140   |[16.53288465159445,83.46711534840551] |1.0       |
| 342-6864   |[16.694165016887318,83.30583498311265]|1.0       |
| 419-1863   |[17.594670105674677,82.4053298943253] |1.0       |
| 384-7176   |[17.92764148018115,82.07235851981882] |1.0       |
| 357-1938   |[18.8550074623437,81.1449925376563]   |1.0       |
| 355-6837   |[19.556608109022648,80.44339189097732]|1.0       |
| 417-1488   |[20.13305147603522,79.86694852396475] |1.0       |
| 394-5489   |[21.05074084178182,78.94925915821818] |1.0       |
| 394-7447   |[21.376663858426735,78.62333614157326]|1.0       |
| 339-6477   |[21.549262081786424,78.45073791821355]|1.0       |
| 406-7844   |[21.92209788389343,78.07790211610656] |1.0       |
| 372-4073   |[22.098599119168263,77.90140088083176]|1.0       |
| 404-4809   |[22.515513847987147,77.48448615201283]|1.0       |
| 347-8659   |[22.66840460762997,77.33159539237005] |1.0       |
| 335-1874   |[23.336632598761128,76.66336740123884]|1.0       |
+------------+--------------------------------------+----------+
only showing top 20 rows

Listing 3-2. Churn prediction with Random Forest

Feature Importance

Random Forest (and other tree-based ensembles) have built-in feature selection capabilities that can be used to measure the importance of each feature in the dataset (see Listing 3-3).

Random Forest calculates feature importance as the sum of the reduction in node impurity for every node at which a feature is used for splitting, aggregated over every tree in the forest and divided by the number of trees. Spark MLlib provides a method that returns an estimate of each feature's importance.

import org.apache.spark.ml.classification.RandomForestClassificationModel
import org.apache.spark.ml.PipelineModel
val bestModel = model.bestModel
val model = bestModel
            .asInstanceOf[PipelineModel]
            .stages
            .last
            .asInstanceOf[RandomForestClassificationModel]
model.featureImportances
feature_importances: org.apache.spark.ml.linalg.Vector =
(13,[0,1,2,3,4,5,6,7,8,9,10,11,12],
[
0.20827010117447803,
0.1667170878866465,
0.06099491253318444,
0.008184141410796346,
0.06664053647245761,
0.0072108752126555,
0.21097011684691344,
0.006902059667276019,
0.06831916361401609,
0.00644772968425685,
0.04105403721675372,
0.056954219262186724,
0.09133501901837866])

Listing 3-3. Displaying the feature importances of a Random Forest

We get back a vector containing the number of features (13 in our example), the array indices of the features, and the corresponding weights. Table 3-1 shows the output in a more readable format, with the actual features displayed alongside their corresponding weights. As you can see, total_day_charge, total_day_minutes, and number_customer_service_calls are the most important features. This makes sense. A high number of customer service calls could indicate multiple service interruptions or a high number of customer complaints. Low total_day_minutes and total_day_charge could indicate that a customer is not using his phone plan very often, which may mean he is getting ready to cancel his plan soon.

Table 3-1. Feature importances for the telco churn prediction example

Index  Feature                         Feature Importance
0      number_customer_service_calls   0.20827010117447803
1      total_day_minutes               0.1667170878866465
2      total_eve_minutes               0.06099491253318444
3      account_length                  0.008184141410796346
4      number_vmail_messages           0.06664053647245761
5      total_day_calls                 0.0072108752126555
6      total_day_charge                0.21097011684691344
7      total_eve_calls                 0.006902059667276019
8      total_eve_charge                0.06831916361401609
9      total_night_calls               0.00644772968425685
10     total_intl_calls                0.04105403721675372
11     total_intl_charge               0.056954219262186724
12     int_plan                        0.09133501901837866

Note: Spark MLlib's implementation of feature importance in Random Forest is also known as Gini importance or mean decrease in impurity (MDI). Some Random Forest implementations use a different method of calculating feature importance, known as accuracy-based importance or mean decrease in accuracy (MDA). Accuracy-based importance is calculated from the decrease in prediction accuracy as features are randomly permuted. Although this method is not directly supported by Spark MLlib's Random Forest implementation, it is fairly straightforward to implement manually by evaluating the model while permuting the values of each feature, one column at a time.
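The following is a minimal sketch of that manual approach. It assumes hypothetical names not defined in the listings above: rawTestData is a test split taken before the VectorAssembler step, and cvModel is the fitted model from the churn example; assembler, features, and evaluator are the objects defined earlier. The single-partition Window is acceptable only for small datasets.

import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions._
import org.apache.spark.sql.expressions.Window

// Shuffle one column independently of the other columns.
def permuteColumn(df: DataFrame, colName: String, seed: Long): DataFrame = {
  val w = Window.orderBy(monotonically_increasing_id())
  val indexed  = df.withColumn("rid", row_number().over(w))
  val shuffled = df.select(col(colName).as(colName + "_perm"))
                   .orderBy(rand(seed))
                   .withColumn("rid", row_number().over(w))
  indexed.join(shuffled, "rid")
         .drop("rid", colName)
         .withColumnRenamed(colName + "_perm", colName)
}

// Baseline AUC on the unpermuted test data.
val baselineAuc = evaluator.evaluate(
  cvModel.transform(assembler.transform(rawTestData)))

// Accuracy-based importance = drop in AUC when a feature is permuted.
val permutationImportance = features.map { f =>
  val permuted = assembler.transform(permuteColumn(rawTestData, f, 42L))
  (f, baselineAuc - evaluator.evaluate(cvModel.transform(permuted)))
}
permutationImportance.sortBy(-_._2).foreach(println)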

It is sometimes useful to examine the parameters used by the best model (see Listing 3-4).

import org.apache.spark.ml.classification.RandomForestClassificationModel
import org.apache.spark.ml.PipelineModel
val bestModel = model
                .bestModel
                .asInstanceOf[PipelineModel]
                .stages
                .last
                .asInstanceOf[RandomForestClassificationModel]
 print(bestModel.extractParamMap)
{
        rfc_81c4d3786152-cacheNodeIds: false,
        rfc_81c4d3786152-checkpointInterval: 10,
        rfc_81c4d3786152-featureSubsetStrategy: auto,
        rfc_81c4d3786152-featuresCol: features,
        rfc_81c4d3786152-impurity: gini,
        rfc_81c4d3786152-labelCol: label,
        rfc_81c4d3786152-maxBins: 10,
        rfc_81c4d3786152-maxDepth: 15,
        rfc_81c4d3786152-maxMemoryInMB: 256,
        rfc_81c4d3786152-minInfoGain: 0.0,
        rfc_81c4d3786152-minInstancesPerNode: 1,
        rfc_81c4d3786152-numTrees: 100,
        rfc_81c4d3786152-predictionCol: prediction,
        rfc_81c4d3786152-probabilityCol: probability,
        rfc_81c4d3786152-rawPredictionCol: rawPrediction,
        rfc_81c4d3786152-seed: 1234,
        rfc_81c4d3786152-subsamplingRate: 1.0
}


Listing 3-4. Extracting the parameters of a Random Forest model

eXtreme Gradient Boosting with XGBoost4J-Spark

Gradient boosting algorithms are among the most powerful machine learning algorithms for classification and regression. There are various implementations of gradient boosting currently available. Popular implementations include AdaBoost and CatBoost (a recently open sourced gradient boosting library from Yandex). Spark MLlib also includes its own gradient-boosted tree (GBT) implementation.

XGBoost (eXtreme Gradient Boosting) is one of the best gradient-boosted tree implementations currently available. Released on March 27, 2014, by Tianqi Chen as a research project, XGBoost has become a dominant machine learning algorithm for classification and regression. Designed for efficiency and scalability, its parallel tree boosting makes it significantly faster than other tree-based ensemble algorithms. Due to its high accuracy, XGBoost has gained popularity by winning several machine learning competitions. In 2015, 17 out of 29 winning solutions on Kaggle used XGBoost, and all of the top-10 solutions at the KDD Cup 2015 used XGBoost.

XGBoost was designed using the general principles of gradient boosting, combining weak learners into a strong learner. But while gradient-boosted trees are built sequentially, slowly learning from the data to improve their predictions in subsequent iterations, XGBoost builds trees in parallel. XGBoost produces better predictive performance by controlling model complexity and reducing overfitting through its built-in regularization. It uses an approximate algorithm to find split points when determining the best splits for continuous features.

The approximate splitting method uses discrete bins to bucket continuous features, significantly speeding up model training. XGBoost also includes a tree growing method based on a histogram algorithm, which provides an even more efficient way of bucketing continuous features into discrete bins. However, while the approximate method creates a new set of bins at each iteration, the histogram-based method reuses the bins over multiple iterations.

This approach allows additional optimizations that are not possible with the approximate method, such as the ability to cache the bins and to derive a sibling's histogram from the parent by subtraction. To optimize sorting operations, XGBoost stores sorted data in in-memory units called blocks. Sorted blocks can be efficiently distributed and processed by parallel CPU cores. XGBoost can handle weighted data efficiently through its weighted quantile sketch algorithm, handles sparse data efficiently, is cache-aware, and supports out-of-core computation by utilizing disk space for large datasets, so the data does not have to fit in memory.

The XGBoost4J-Spark project was launched in late 2016 to port XGBoost to Spark. XGBoost4J-Spark takes advantage of Spark's highly scalable distributed processing engine and is fully compatible with Spark MLlib's DataFrame/Dataset abstraction. XGBoost4J-Spark can be embedded seamlessly into Spark MLlib pipelines and integrates with Spark MLlib transformers and estimators.

Note: XGBoost4J-Spark requires Apache Spark 2.4+. It is recommended to install Spark directly from http://spark.apache.org. XGBoost4J-Spark is not guaranteed to work with third-party Spark distributions from other vendors such as Cloudera, Hortonworks, or MapR. Consult your vendor's documentation for details.

Parameters

XGBoost has considerably more parameters than Random Forest and generally requires more tuning. Focusing on the most important parameters initially can get you started with XGBoost. You can learn the rest as you get comfortable with the algorithm.

  • max_depth: Specifies the maximum depth of the trees. Setting a higher value for max_depth may increase the likelihood of overfitting and make the model more complex.
  • n_estimators: Specifies the number of trees to fit. In general, larger values are better. Setting this parameter too high may affect training speed. Beyond a certain point, adding more trees may not improve accuracy. The default value is 100.
  • subsample: Specifies the fraction of the data that will be selected for each tree. Setting this parameter can improve training speed and help prevent overfitting. Setting it too low may cause underfitting.
  • colsample_bytree: Specifies the fraction of columns that will be randomly selected for each tree. Setting this parameter can improve training speed and help prevent overfitting. Related parameters include colsample_bylevel and colsample_bynode.
  • objective: Specifies the learning task and learning objective. It is important to set the correct value for this parameter to avoid unpredictable results or poor accuracy. For binary classification, XGBoostClassifier defaults to binary:logistic, while XGBoostRegressor defaults to reg:squarederror. Other values include multi:softmax and multi:softprob for multiclass classification; rank:pairwise, rank:ndcg, and rank:map for ranking; and survival:cox for survival regression using the Cox proportional hazards model, to name a few.
  • learning_rate (eta): learning_rate acts as a shrinkage factor, reducing the feature weights after each boosting step with the aim of slowing down the learning. This parameter is used to control overfitting. Lower values require more trees.
  • n_jobs: Specifies the number of parallel threads used by XGBoost (nthread is deprecated; use this parameter instead).

These are only general guidelines on how to use the parameters. Performing a parameter grid search to determine the optimal values for these parameters is highly recommended. Refer to XGBoost's online documentation for a complete list of XGBoost parameters.
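As a minimal sketch of how these snake_case parameters could be passed to XGBoost4J-Spark's XGBoostClassifier (the values are illustrative only, and num_round plays the role of the tree count in this API):

import ml.dmlc.xgboost4j.scala.spark.XGBoostClassifier

// Hypothetical parameter values; tune them with a grid search in practice.
val xgbParams = Map(
  "objective"        -> "binary:logistic",  // learning task and objective
  "max_depth"        -> 6,                  // maximum tree depth
  "eta"              -> 0.1,                // learning rate (shrinkage)
  "subsample"        -> 0.8,                // row sampling per tree
  "colsample_bytree" -> 0.8,                // column sampling per tree
  "num_round"        -> 100                 // number of boosting rounds
)

val xgbSketch = new XGBoostClassifier(xgbParams)
  .setLabelCol("label")
  .setFeaturesCol("features")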

Note: To remain consistent with Scala's variable naming conventions, XGBoost4J-Spark supports both the default parameter names and camelCase variants of those parameters (for example, max_depth and maxDepth).

Example

We will reuse the same telco churn dataset and most of the code from the earlier Random Forest example (see Listing 3-5). This time we will use a pipeline to chain the transformers and estimators together.

XGBoost4J-Spark is available as an external package.

spark-shell --packages ml.dmlc:xgboost4j-spark:0.81

Load the CSV file into a DataFrame.

val dataDF = spark.read.format("csv")
             .option("header", "true")
             .option("inferSchema", "true")
             .load("churn_data.txt")

Inspect the schema.

dataDF.printSchema
root
 |-- state: string (nullable = true)
 |-- account_length: double (nullable = true)
 |-- area_code: double (nullable = true)
 |-- phone_number: string (nullable = true)
 |-- international_plan: string (nullable = true)
 |-- voice_mail_plan: string (nullable = true)
 |-- number_vmail_messages: double (nullable = true)
 |-- total_day_minutes: double (nullable = true)
 |-- total_day_calls: double (nullable = true)
 |-- total_day_charge: double (nullable = true)
 |-- total_eve_minutes: double (nullable = true)
 |-- total_eve_calls: double (nullable = true)
 |-- total_eve_charge: double (nullable = true)
 |-- total_night_minutes: double (nullable = true)
 |-- total_night_calls: double (nullable = true)
 |-- total_night_charge: double (nullable = true)
 |-- total_intl_minutes: double (nullable = true)
 |-- total_intl_calls: double (nullable = true)
 |-- total_intl_charge: double (nullable = true)
 |-- number_customer_service_calls: double (nullable = true)
 |-- churned: string (nullable = true)


Select a few columns.

dataDF.select("state","phone_number","international_plan","churned").show
+-----+------------+------------------+-------+
|state|phone_number|international_plan|churned|
+-----+------------+------------------+-------+
|   KS|    382-4657|                no|  False|
|   OH|    371-7191|                no|  False|
|   NJ|    358-1921|                no|  False|
|   OH|    375-9999|               yes|  False|
|   OK|    330-6626|               yes|  False|
|   AL|    391-8027|               yes|  False|
|   MA|    355-9993|                no|  False|
|   MO|    329-9001|               yes|  False|
|   LA|    335-4719|                no|  False|
|   WV|    330-8173|               yes|  False|
|   IN|    329-6603|                no|   True|
|   RI|    344-9403|                no|  False|
|   IA|    363-1107|                no|  False|
|   MT|    394-8006|                no|  False|
|   IA|    366-9238|                no|  False|
|   NY|    351-7269|                no|   True|
|   ID|    350-8884|                no|  False|
|   VT|    386-2923|                no|  False|
|   VA|    356-2992|                no|  False|
|   TX|    373-2782|                no|  False|
+-----+------------+------------------+-------+
only showing top 20 rows
import org.apache.spark.ml.feature.StringIndexer


Convert the string column "churned" ("True", "False") into a double (1, 0).

val labelIndexer = new StringIndexer()
                   .setInputCol("churned")
                   .setOutputCol("label")


Convert the string column "international_plan" ("yes", "no") into a double (1, 0).

val intPlanIndexer = new StringIndexer()
                     .setInputCol("international_plan")
                     .setOutputCol("int_plan")


Specify the features to select for model fitting.

val features = Array("number_customer_service_calls","total_day_minutes","total_eve_minutes","account_length","number_vmail_messages","total_day_calls","total_day_charge","total_eve_calls","total_eve_charge","total_night_calls","total_intl_calls","total_intl_charge","int_plan")

Combine the features into a single vector column.

import org.apache.spark.ml.feature.VectorAssembler
val assembler = new VectorAssembler()
                .setInputCols(features)
                .setOutputCol("features")


Split the data into training and test datasets.

val seed = 1234
val Array(trainingData, testData) = dataDF.randomSplit(Array(0.8, 0.2), seed)


Create an XGBoost classifier.

import ml.dmlc.xgboost4j.scala.spark.XGBoostClassifier
import ml.dmlc.xgboost4j.scala.spark.XGBoostClassificationModel
val xgb = new XGBoostClassifier()
          .setFeaturesCol("features")
          .setLabelCol("label")


The objective parameter of XGBoostClassifier defaults to binary:logistic, which is the learning task and objective we want (binary classification). Depending on your task, remember to set the correct learning task and objective.

import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
val evaluator = new MulticlassClassificationEvaluator().setLabelCol("label")
import org.apache.spark.ml.tuning.ParamGridBuilder
val paramGrid = new ParamGridBuilder()
                .addGrid(xgb.maxDepth, Array(3, 8))
                .addGrid(xgb.eta, Array(0.2, 0.6))
                .build()


This time, we will specify all the steps in the pipeline.

import org.apache.spark.ml.{ Pipeline, PipelineStage }
val pipeline = new Pipeline()
               .setStages(Array(labelIndexer, intPlanIndexer, assembler, xgb))


Create the cross-validator.

import org.apache.spark.ml.tuning.CrossValidator
val cv = new CrossValidator()
         .setEstimator(pipeline)
         .setEvaluator(evaluator)
         .setEstimatorParamMaps(paramGrid)
         .setNumFolds(3)

We can now fit the model using the training data. This will run cross-validation, selecting the best set of parameters.

val model = cv.fit(trainingData)

You can now make predictions on our test data.

val predictions = model.transform(testData)
predictions.printSchema
root
 |-- state: string (nullable = true)
 |-- account_length: double (nullable = true)
 |-- area_code: double (nullable = true)
 |-- phone_number: string (nullable = true)
 |-- international_plan: string (nullable = true)
 |-- voice_mail_plan: string (nullable = true)
 |-- number_vmail_messages: double (nullable = true)
 |-- total_day_minutes: double (nullable = true)
 |-- total_day_calls: double (nullable = true)
 |-- total_day_charge: double (nullable = true)
 |-- total_eve_minutes: double (nullable = true)
 |-- total_eve_calls: double (nullable = true)
 |-- total_eve_charge: double (nullable = true)
 |-- total_night_minutes: double (nullable = true)
 |-- total_night_calls: double (nullable = true)
 |-- total_night_charge: double (nullable = true)
 |-- total_intl_minutes: double (nullable = true)
 |-- total_intl_calls: double (nullable = true)
 |-- total_intl_charge: double (nullable = true)
 |-- number_customer_service_calls: double (nullable = true)
 |-- churned: string (nullable = true)
 |-- label: double (nullable = false)
 |-- int_plan: double (nullable = false)
 |-- features: vector (nullable = true)
 |-- rawPrediction: vector (nullable = true)
 |-- probability: vector (nullable = true)
 |-- prediction: double (nullable = false)

Let's evaluate the model.

val auc = evaluator.evaluate(predictions)
auc: Double = 0.9328044307445879

The score produced by XGBoost4J-Spark is slightly better than that of our earlier Random Forest example. XGBoost4J-Spark was also faster than Random Forest at training on this dataset.
Like Random Forest, XGBoost lets you extract feature importances.

import ml.dmlc.xgboost4j.scala.spark.XGBoostClassificationModel
import org.apache.spark.ml.PipelineModel
val bestModel = model.bestModel
val model = bestModel
            .asInstanceOf[PipelineModel]
            .stages
            .last
            .asInstanceOf[XGBoostClassificationModel]

Execute the getFeatureScore method to extract the feature importances.

model.nativeBooster.getFeatureScore()
res9: scala.collection.mutable.Map[String,Integer] = Map(f7 -> 4, f9 -> 7, f10 -> 2, f12 -> 4, f11 -> 8, f0 -> 5, f1 -> 19, f2 -> 17, f3 -> 10, f4 -> 2, f5 -> 3)

The method returns a map with keys corresponding to the feature array indices and values corresponding to the feature importance scores.

Listing 3-5. Churn prediction with XGBoost4J-Spark

Table 3-2. Feature importances using XGBoost4J-Spark

Index  Feature                         Feature Importance
0      number_customer_service_calls   2
1      total_day_minutes               15
2      total_eve_minutes               10
3      account_length                  3
4      number_vmail_messages           2
5      total_day_calls                 3
6      total_day_charge                Omitted
7      total_eve_calls                 2
8      total_eve_charge                Omitted
9      total_night_calls               2
10     total_intl_calls                2
11     total_intl_charge               1
12     int_plan                        5

Notice that the output is missing a couple of columns, specifically total_day_charge (f6) and total_eve_charge (f8). XGBoost considered these features ineffective in improving the model's predictive accuracy (see Table 3-2). Only features that are used in at least one split make it into the XGBoost feature importance output. There are several possible explanations. It could mean that the dropped features have very low or zero variance. It could also mean that the two features are highly correlated with other features.

There are a few interesting things to note when comparing XGBoost's feature importance output with that of our earlier Random Forest example. While the earlier Random Forest model considered number_customer_service_calls one of the most important features, XGBoost ranks it as one of the least important. Similarly, the earlier Random Forest model considered total_day_charge the most important feature, but XGBoost omitted it from the output entirely due to its lack of importance (see Listing 3-6).

val bestModel = model
                .bestModel
                .asInstanceOf[PipelineModel]
                .stages
                .last
                .asInstanceOf[XGBoostClassificationModel]
print(bestModel.extractParamMap)
{
        xgbc_9b95e70ab140-alpha: 0.0,
        xgbc_9b95e70ab140-baseScore: 0.5,
        xgbc_9b95e70ab140-checkpointInterval: -1,
        xgbc_9b95e70ab140-checkpointPath: ,
        xgbc_9b95e70ab140-colsampleBylevel: 1.0,
        xgbc_9b95e70ab140-colsampleBytree: 1.0,
        xgbc_9b95e70ab140-customEval: null,
        xgbc_9b95e70ab140-customObj: null,
        xgbc_9b95e70ab140-eta: 0.2,
        xgbc_9b95e70ab140-evalMetric: error,
        xgbc_9b95e70ab140-featuresCol: features,
        xgbc_9b95e70ab140-gamma: 0.0,
        xgbc_9b95e70ab140-growPolicy: depthwise,
        xgbc_9b95e70ab140-labelCol: label,
        xgbc_9b95e70ab140-lambda: 1.0,
        xgbc_9b95e70ab140-lambdaBias: 0.0,
        xgbc_9b95e70ab140-maxBin: 16,
        xgbc_9b95e70ab140-maxDeltaStep: 0.0,
        xgbc_9b95e70ab140-maxDepth: 8,
        xgbc_9b95e70ab140-minChildWeight: 1.0,
        xgbc_9b95e70ab140-missing: NaN,
        xgbc_9b95e70ab140-normalizeType: tree,
        xgbc_9b95e70ab140-nthread: 1,
        xgbc_9b95e70ab140-numEarlyStoppingRounds: 0,
        xgbc_9b95e70ab140-numRound: 1,
        xgbc_9b95e70ab140-numWorkers: 1,
        xgbc_9b95e70ab140-objective: reg:linear,
        xgbc_9b95e70ab140-predictionCol: prediction,
        xgbc_9b95e70ab140-probabilityCol: probability,
        xgbc_9b95e70ab140-rateDrop: 0.0,
        xgbc_9b95e70ab140-rawPredictionCol: rawPrediction,
        xgbc_9b95e70ab140-sampleType: uniform,
        xgbc_9b95e70ab140-scalePosWeight: 1.0,
        xgbc_9b95e70ab140-seed: 0,
        xgbc_9b95e70ab140-silent: 0,
        xgbc_9b95e70ab140-sketchEps: 0.03,
        xgbc_9b95e70ab140-skipDrop: 0.0,
        xgbc_9b95e70ab140-subsample: 1.0,
        xgbc_9b95e70ab140-timeoutRequestWorkers: 1800000,
        xgbc_9b95e70ab140-trackerConf: TrackerConf(0,python),
        xgbc_9b95e70ab140-trainTestRatio: 1.0,
        xgbc_9b95e70ab140-treeLimit: 0,
        xgbc_9b95e70ab140-treeMethod: auto,
        xgbc_9b95e70ab140-useExternalMemory: false
}

Listing 3-6. Extracting the parameters of an XGBoost4J-Spark model

LightGBM: Fast Gradient Boosting from Microsoft

For years, XGBoost has been everyone's favorite algorithm for classification and regression. More recently, LightGBM has emerged as a new challenger to the throne. It is a relatively new tree-based gradient boosting variant similar to XGBoost. Released on October 17, 2016, LightGBM is part of Microsoft's Distributed Machine Learning Toolkit (DMTK) project. It was designed to be fast and distributed, resulting in faster training speed and low memory usage. It supports GPU and parallel learning as well as the ability to handle large datasets. LightGBM has been shown in several benchmarks and experiments on public datasets to be even faster than XGBoost, with better accuracy.

Note: LightGBM has been ported to Spark as part of the Microsoft Machine Learning for Apache Spark (MMLSpark) ecosystem. Microsoft has been actively developing data science and deep learning tools that integrate seamlessly with the Apache Spark ecosystem, such as the Microsoft Cognitive Toolkit, OpenCV, and LightGBM. MMLSpark requires Python 2.7 or 3.5+, Scala 2.11, and Spark 2.3+.

LightGBM has several advantages over XGBoost. It utilizes histograms to bucket continuous features into discrete bins. This gives LightGBM several performance advantages over XGBoost (which by default uses a presorted-based algorithm for tree learning), such as reduced memory usage, reduced cost of calculating the gain for each split, and reduced communication cost for parallel learning. LightGBM achieves an additional performance boost by computing a node's histogram through histogram subtraction, using its sibling and its parent. Online benchmarks show LightGBM is 11 to 15 times faster than XGBoost (without binning) on some tasks.

LightGBM often outperforms XGBoost in terms of accuracy by growing trees leaf-wise (best-first). There are two main strategies for training decision trees, level-wise and leaf-wise (shown in Figure 3-7). Level-wise tree growth is the traditional way of growing decision trees for most tree-based ensembles (including XGBoost). LightGBM introduced the leaf-wise growth strategy. Compared to level-wise growth, leaf-wise growth usually converges faster and achieves lower loss.

Note: Leaf-wise growth tends to overfit on small datasets. It is recommended to set the max_depth parameter in LightGBM to limit tree depth. Note that even when max_depth is set, trees still grow leaf-wise. I discuss LightGBM parameter tuning later in this chapter.

Figure 3-7. Level-wise growth vs. leaf-wise growth

Note: XGBoost has since implemented many of the optimizations pioneered by LightGBM, including the leaf-wise tree growth strategy and the use of histograms to bucket continuous features into discrete bins. Recent benchmarks show XGBoost performing competitively with LightGBM.

Parameters

Tuning LightGBM is slightly more involved than tuning other algorithms such as Random Forest. LightGBM uses a leaf-wise (best-first) tree growth algorithm, which can be susceptible to overfitting if the parameters are not configured correctly. In addition, LightGBM has more than 100 parameters. Focusing on the most important parameters is enough to get you started with LightGBM. You can learn the rest as you become more familiar with the algorithm.

  • max_depth: Set this parameter to keep trees from growing too deep. Shallow trees are less likely to overfit. Setting this parameter is particularly important if the dataset is small.
  • num_leaves: Controls the complexity of the tree model. The value should be smaller than 2^(max_depth) to prevent overfitting. Setting num_leaves to a large value can increase accuracy, at the risk of a higher chance of overfitting. Setting num_leaves to a small value helps prevent overfitting.
  • min_data_in_leaf: Setting this parameter to a larger value can keep trees from growing too deep. This is another parameter you can set to help control overfitting. Setting the value too large may cause underfitting.
  • max_bin: LightGBM uses histograms to group the values of continuous features into discrete buckets. Set max_bin to specify the number of bins that values will be grouped into. A smaller value can help control overfitting and improve training speed, while a larger value improves accuracy.
  • feature_fraction: This parameter enables feature subsampling. It specifies the fraction of features that will be randomly selected in each iteration. For example, setting feature_fraction to 0.75 will randomly select 75% of the features in each iteration. Setting this parameter can improve training speed and help prevent overfitting.
  • bagging_fraction: Specifies the fraction of data that will be selected in each iteration. For example, setting bagging_fraction to 0.75 will randomly select 75% of the data in each iteration. Setting this parameter can improve training speed and help prevent overfitting.
  • num_iterations: Sets the number of boosting iterations. The default is 100. For multiclass classification, LightGBM builds num_class * num_iterations trees. Setting this parameter affects training speed.
  • objective: Like XGBoost, LightGBM supports multiple objectives. The default objective is regression. Set this parameter to specify the type of task the model is trying to perform. For regression tasks, options include regression_l2, regression_l1, poisson, quantile, mape, gamma, huber, fair, or tweedie. For classification tasks, the options are binary, multiclass, or multiclassova. It is important to set the objective correctly to avoid unpredictable results or poor accuracy.

As always, performing a parameter grid search to determine the optimal values for these parameters is highly recommended. Refer to the LightGBM online documentation for an exhaustive list of LightGBM parameters.
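As a minimal sketch of how several of these parameters map onto setters of MMLSpark's LightGBMClassifier (the values are illustrative, and the setter names should be verified against the MMLSpark version you are using):

import com.microsoft.ml.spark.LightGBMClassifier

val lgbmSketch = new LightGBMClassifier()
  .setObjective("binary")        // binary classification
  .setNumLeaves(31)              // tree complexity; keep below 2^(maxDepth)
  .setMaxDepth(6)                // limit depth to control overfitting
  .setMaxBin(255)                // histogram bins for continuous features
  .setFeatureFraction(0.75)      // feature subsampling per iteration
  .setBaggingFraction(0.75)      // row subsampling per iteration
  .setNumIterations(200)         // boosting iterations
  .setLabelCol("label")
  .setFeaturesCol("features")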

Note: At the time of this writing, LightGBM for Spark has not yet reached feature parity with LightGBM for Python. While LightGBM for Spark includes the most important parameters, it is still missing some. You can get the list of all the parameters available in LightGBM for Spark by visiting https://bit.ly/2OqHl2M. You can compare it with the full list of LightGBM parameters at https://bit.ly/30YGyaO.

Example

We will reuse the same telco churn dataset and most of the code from the earlier Random Forest and XGBoost examples, as shown in Listing 3-7.
spark-shell --packages Azure:mmlspark:0.15
Load the CSV file into a DataFrame.

val dataDF = spark.read.format("csv")
             .option("header", "true")
             .option("inferSchema", "true")
             .load("churn_data.txt")


Inspect the schema.

dataDF.printSchema
root
 |-- state: string (nullable = true)
 |-- account_length: double (nullable = true)
 |-- area_code: double (nullable = true)
 |-- phone_number: string (nullable = true)
 |-- international_plan: string (nullable = true)
 |-- voice_mail_plan: string (nullable = true)
 |-- number_vmail_messages: double (nullable = true)
 |-- total_day_minutes: double (nullable = true)
 |-- total_day_calls: double (nullable = true)
 |-- total_day_charge: double (nullable = true)
 |-- total_eve_minutes: double (nullable = true)
 |-- total_eve_calls: double (nullable = true)
 |-- total_eve_charge: double (nullable = true)
 |-- total_night_minutes: double (nullable = true)
 |-- total_night_calls: double (nullable = true)
 |-- total_night_charge: double (nullable = true)
 |-- total_intl_minutes: double (nullable = true)
 |-- total_intl_calls: double (nullable = true)
 |-- total_intl_charge: double (nullable = true)
 |-- number_customer_service_calls: double (nullable = true)
 |-- churned: string (nullable = true)


Select a few columns.

dataDF.select("state","phone_number","international_plan","churned").show
+-----+------------+------------------+-------+
|state|phone_number|international_plan|churned|
+-----+------------+------------------+-------+
|   KS|    382-4657|                no|  False|
|   OH|    371-7191|                no|  False|
|   NJ|    358-1921|                no|  False|
|   OH|    375-9999|               yes|  False|
|   OK|    330-6626|               yes|  False|
|   AL|    391-8027|               yes|  False|
|   MA|    355-9993|                no|  False|
|   MO|    329-9001|               yes|  False|
|   LA|    335-4719|                no|  False|
|   WV|    330-8173|               yes|  False|
|   IN|    329-6603|                no|   True|
|   RI|    344-9403|                no|  False|
|   IA|    363-1107|                no|  False|
|   MT|    394-8006|                no|  False|
|   IA|    366-9238|                no|  False|
|   NY|    351-7269|                no|   True|
|   ID|    350-8884|                no|  False|
|   VT|    386-2923|                no|  False|
|   VA|    356-2992|                no|  False|
|   TX|    373-2782|                no|  False|
+-----+------------+------------------+-------+
only showing top 20 rows
import org.apache.spark.ml.feature.StringIndexer
val labelIndexer = new StringIndexer().setInputCol("churned").setOutputCol("label")
val intPlanIndexer = new StringIndexer().setInputCol("international_plan").setOutputCol("int_plan")
val features = Array("number_customer_service_calls","total_day_minutes","total_eve_minutes","account_length","number_vmail_messages","total_day_calls","total_day_charge","total_eve_calls","total_eve_charge","total_night_calls","total_intl_calls","total_intl_charge","int_plan")
import org.apache.spark.ml.feature.VectorAssembler
val assembler = new VectorAssembler()
                .setInputCols(features)
                .setOutputCol("features")
val seed = 1234
val Array(trainingData, testData) = dataDF.randomSplit(Array(0.9, 0.1), seed)


Create a LightGBM classifier.

import com.microsoft.ml.spark.LightGBMClassifier
val lightgbm = new LightGBMClassifier()
               .setFeaturesCol("features")
               .setLabelCol("label")
               .setRawPredictionCol("rawPrediction")
               .setObjective("binary")

Remember to set the correct objective using the setObjective method. Specifying an incorrect objective may affect accuracy or produce unpredictable results. In LightGBM, the default objective is set to regression. In this example we are performing binary classification, so we set the objective to binary.

Listing 3-7. Churn prediction with LightGBM

Note: Spark supports barrier execution mode starting with version 2.4. LightGBM supports barrier execution mode starting with version 0.18, via the useBarrierExecutionMode setting.

import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
val evaluator = new BinaryClassificationEvaluator()
                .setLabelCol("label")
                .setMetricName("areaUnderROC")
import org.apache.spark.ml.tuning.ParamGridBuilder
val paramGrid = new ParamGridBuilder()
                .addGrid(lightgbm.maxDepth, Array(2, 3, 4))
                .addGrid(lightgbm.numLeaves, Array(4, 6, 8))
                .addGrid(lightgbm.numIterations, Array(600))
                .build()
import org.apache.spark.ml.{ Pipeline, PipelineStage }
val pipeline = new Pipeline()
               .setStages(Array(labelIndexer, intPlanIndexer, assembler, lightgbm))
import org.apache.spark.ml.tuning.CrossValidator
val cv = new CrossValidator()
         .setEstimator(pipeline)
         .setEvaluator(evaluator)
         .setEstimatorParamMaps(paramGrid)
         .setNumFolds(3)
val model = cv.fit(trainingData)

You can now make predictions on our test data.

val predictions = model.transform(testData)
predictions.printSchema
root
 |-- state: string (nullable = true)
 |-- account_length: double (nullable = true)
 |-- area_code: double (nullable = true)
 |-- phone_number: string (nullable = true)
 |-- international_plan: string (nullable = true)
 |-- voice_mail_plan: string (nullable = true)
 |-- number_vmail_messages: double (nullable = true)
 |-- total_day_minutes: double (nullable = true)
 |-- total_day_calls: double (nullable = true)
 |-- total_day_charge: double (nullable = true)
 |-- total_eve_minutes: double (nullable = true)
 |-- total_eve_calls: double (nullable = true)
 |-- total_eve_charge: double (nullable = true)
 |-- total_night_minutes: double (nullable = true)
 |-- total_night_calls: double (nullable = true)
 |-- total_night_charge: double (nullable = true)
 |-- total_intl_minutes: double (nullable = true)
 |-- total_intl_calls: double (nullable = true)
 |-- total_intl_charge: double (nullable = true)
 |-- number_customer_service_calls: double (nullable = true)
 |-- churned: string (nullable = true)
 |-- label: double (nullable = false)
 |-- int_plan: double (nullable = false)
 |-- features: vector (nullable = true)
 |-- rawPrediction: vector (nullable = true)
 |-- probability: vector (nullable = true)
 |-- prediction: double (nullable = false)


Evaluate the model. The AUC score is higher than both the Random Forest and XGBoost results from our earlier examples.

val auc = evaluator.evaluate(predictions)
auc: Double = 0.940366124260358


LightGBM also lets you extract feature importances.

import com.microsoft.ml.spark.LightGBMClassificationModel
import org.apache.spark.ml.PipelineModel
// Pull the trained LightGBM model out of the last stage of the best
// pipeline found by cross-validation.
val bestModel = model.bestModel
val model = bestModel.asInstanceOf[PipelineModel]
            .stages
            .last
            .asInstanceOf[LightGBMClassificationModel]

There are two types of feature importance in LightGBM: "split" (the total number of splits) and "gain" (the total information gain). Using "gain" is generally recommended. It is broadly similar to how random forest computes feature importance, except that in our binary classification example LightGBM uses cross-entropy (log loss) rather than Gini impurity (see Tables 3-3 and 3-4). The loss being minimized depends on the specified objective.

val gainFeatureImportances = model.getFeatureImportances("gain")
gainFeatureImportances: Array[Double] =
Array(2648.0893859118223, 5339.0795262902975, 2191.832309693098,564.6461282968521, 1180.4672759771347, 656.8244850635529, 0.0, 533.6638155579567, 579.7435692846775, 651.5408382415771, 1179.492751300335, 2186.5995585918427, 1773.7864662855864)

Table 3-3 LightGBM Feature Importance Using Information Gain

Index  Feature                          Feature Importance
0      number_customer_service_calls    2648.0893859118223
1      total_day_minutes                5339.0795262902975
2      total_eve_minutes                2191.832309693098
3      account_length                   564.6461282968521
4      number_vmail_messages            1180.4672759771347
5      total_day_calls                  656.8244850635529
6      total_day_charge                 0.0
7      total_eve_calls                  533.6638155579567
8      total_eve_charge                 579.7435692846775
9      total_night_calls                651.5408382415771
10     total_intl_calls                 1179.492751300335
11     total_intl_charge                2186.5995585918427
12     int_plan                         1773.7864662855864

Compare the output when using "split".

val splitFeatureImportances = model.getFeatureImportances("split")
splitFeatureImportances: Array[Double] = Array(159.0, 583.0, 421.0, 259.0, 133.0, 264.0, 0.0, 214.0, 92.0, 279.0, 279.0, 366.0, 58.0)


Table 3-4 LightGBM Feature Importance Using Number of Splits

Index  Feature                          Feature Importance
0      number_customer_service_calls    159.0
1      total_day_minutes                583.0
2      total_eve_minutes                421.0
3      account_length                   259.0
4      number_vmail_messages            133.0
5      total_day_calls                  264.0
6      total_day_charge                 0.0
7      total_eve_calls                  214.0
8      total_eve_charge                 92.0
9      total_night_calls                279.0
10     total_intl_calls                 279.0
11     total_intl_charge                366.0
12     int_plan                         58.0

println(s"True Negative: ${predictions.select("*").where("prediction = 0 AND label = 0").count()}  True Positive: ${predictions.select("*").where("prediction = 1 AND label = 1").count()}")
True Negative: 407  True Positive: 58
println(s"False Negative: ${predictions.select("*").where("prediction = 0 AND label = 1").count()} False Positive: ${predictions.select("*").where("prediction = 1 AND label = 0").count()}")
False Negative: 20 False Positive: 9
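
Beyond AUC, the same confusion-matrix counts can be turned into precision, recall, and accuracy. The snippet below is a supplementary sketch, not part of the original listing:

// Derive precision, recall, and accuracy from the confusion-matrix counts.
val tp = predictions.where("prediction = 1 AND label = 1").count().toDouble
val tn = predictions.where("prediction = 0 AND label = 0").count().toDouble
val fp = predictions.where("prediction = 1 AND label = 0").count().toDouble
val fn = predictions.where("prediction = 0 AND label = 1").count().toDouble
println(s"Precision: ${tp / (tp + fp)}  Recall: ${tp / (tp + fn)}  Accuracy: ${(tp + tn) / (tp + tn + fp + fn)}")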

Sentiment Analysis with Naive Bayes

Naive Bayes is a simple multiclass linear classification algorithm based on Bayes' theorem. It got its name because it naively assumes that the features in a dataset are independent, ignoring any possible correlations between them. This is rarely true in real-world scenarios, yet naive Bayes still tends to perform well, particularly on small datasets or datasets with high dimensionality. Like other linear classifiers, it performs poorly on nonlinear classification problems. Naive Bayes is computationally efficient and highly scalable, requiring only a single pass over the dataset. It is a good baseline model for classification tasks on large datasets. It works by finding the probability that a point belongs to a class given a set of features. Bayes' theorem can be stated as:

P(A|B) = P(B|A) × P(A) / P(B)

P(A|B) is the posterior probability, which can be read as "what is the probability of event A given event B?", where B represents the feature vector. The numerator is the conditional probability (likelihood) multiplied by the prior probability; the denominator is the evidence. Writing B out as the feature vector (x1, ..., xn), the equation can be stated more precisely as:

P(y | x1, ..., xn) = P(x1, ..., xn | y) × P(y) / P(x1, ..., xn)

Naive Bayes is frequently used for text classification. Common applications include spam detection and document categorization. Another text classification use case is sentiment analysis: companies routinely examine comments from social media to determine whether public opinion about a product or service is positive or negative, and hedge funds use sentiment analysis to forecast stock market movements.

Spark MLlib supports both Bernoulli naive Bayes and multinomial naive Bayes. Bernoulli naive Bayes works only with Boolean or binary features (for example, the presence or absence of a word in a document), while multinomial naive Bayes is designed for discrete features (for example, word counts). The default model type in the MLlib implementation is multinomial. You can also set another parameter, the smoothing (Laplace, sometimes called lambda) parameter, whose default value is 1.0.
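For reference, a minimal sketch of setting these parameters explicitly (the example that follows simply uses the defaults):

import org.apache.spark.ml.classification.NaiveBayes
// "multinomial" is the default model type; "bernoulli" expects binary features.
val nbBernoulli = new NaiveBayes()
                  .setModelType("bernoulli")
                  .setSmoothing(1.0)   // additive (Laplace) smoothing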

Example

Let's walk through an example of using naive Bayes for sentiment analysis. We will use a popular dataset from the UC Irvine Machine Learning Repository. The dataset was created for the paper "From Group to Individual Labels Using Deep Features," Kotzias et al., 2015. It contains reviews from three different companies: IMDB, Amazon, and Yelp, with 500 positive and 500 negative reviews for each company. We will use the Amazon dataset to determine, from Amazon product reviews, the probability that the sentiment for a particular item is positive (1) or negative (0).

We need to convert each sentence in the dataset into a feature vector. Spark MLlib provides a transformer for exactly this purpose. Term frequency-inverse document frequency (TF-IDF) is commonly used to generate feature vectors from text. TF-IDF determines the relevance of a word to a document in the corpus by counting how many times the word appears in the document (TF) and how often the word appears across the whole corpus (IDF). In Spark, TF and IDF are implemented separately (HashingTF and IDF).
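A minimal sketch of chaining the two (assuming a tokenized "words" column already exists); note that the example in Listing 3-8 uses HashingTF only, without the IDF weighting step:

import org.apache.spark.ml.feature.{HashingTF, IDF}
// HashingTF produces raw term-frequency vectors; IDF is an estimator that is
// fit on the TF output and rescales it by inverse document frequency.
val tf = new HashingTF().setInputCol("words").setOutputCol("rawFeatures").setNumFeatures(1000)
val idf = new IDF().setInputCol("rawFeatures").setOutputCol("features")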

Before converting words into feature vectors with TF-IDF, we need another transformer, Tokenizer, to split the sentences into individual words. The steps are shown in Figure 3-8, and the code appears in Listing 3-8.

Figure 3-8 Feature transformation for the sentiment analysis example

// Start by creating a schema for our dataset.
import org.apache.spark.sql.types._
var reviewsSchema = StructType(Array (
    StructField("text",   StringType, true),
    StructField("label",  IntegerType, true)
    ))
// Create a DataFrame from the tab-delimited text file.
// Use the "csv" format regardless of whether it is tab or comma delimited.
// The file does not have a header, so we’ll set the header
// option to false. We’ll set delimiter to tab and use the schema
// that we just built.
val reviewsDF = spark.read.format("csv")
                .option("header", "false")
                .option("delimiter","\t")
                .schema(reviewsSchema)
                .load("/files/amazon_cells_labelled.txt")
// Review the schema.
reviewsDF.printSchema
root
 |-- text: string (nullable = true)
 |-- label: integer (nullable = true)
// Check the data.
reviewsDF.show
+--------------------+-----+
|                text|label|
+--------------------+-----+
|So there is no wa...|    0|
|Good case, Excell...|    1|
|Great for the jaw...|    1|
|Tied to charger f...|    0|
|   The mic is great.|    1|
|I have to jiggle ...|    0|
|If you have sever...|    0|
|If you are Razr o...|    1|
|Needless to say, ...|    0|
|What a waste of m...|    0|
|And the sound qua...|    1|
|He was very impre...|    1|
|If the two were s...|    0|
|Very good quality...|    1|
|The design is ver...|    0|
|Highly recommend ...|    1|
|I advise EVERYONE...|    0|
|    So Far So Good!.|    1|
|       Works great!.|    1|
|It clicks into pl...|    0|
+--------------------+-----+
only showing top 20 rows
// Let's do some row counts.
reviewsDF.createOrReplaceTempView("reviews")
spark.sql("select label,count(*) from reviews group by label").show
+-----+--------+
|label|count(1)|
+-----+--------+
|    1|     500|
|    0|     500|
+-----+--------+
// Randomly divide the dataset into training and test datasets.
val seed = 1234
val Array(trainingData, testData) = reviewsDF.randomSplit(Array(0.8, 0.2), seed)
trainingData.count
res5: Long = 827
testData.count
res6: Long = 173
// Split the sentences into words.
import org.apache.spark.ml.feature.Tokenizer
val tokenizer = new Tokenizer().setInputCol("text")
                .setOutputCol("words")
// Check the tokenized data.
val tokenizedDF = tokenizer.transform(trainingData)
tokenizedDF.show
+--------------------+-----+--------------------+
|                text|label|               words|
+--------------------+-----+--------------------+
|         (It works!)|    1|      [(it, works!)]|
|)Setup couldn't h...|    1|[)setup, couldn't...|
|* Comes with a st...|    1|[*, comes, with, ...|
|.... Item arrived...|    1|[...., item, arri...|
|1. long lasting b...|    0|[1., long, lastin...|
|2 thumbs up to th...|    1|[2, thumbs, up, t...|
|:-)Oh, the charge...|    1|[:-)oh,, the, cha...|
|   A Disappointment.|    0|[a, disappointment.]|
|A PIECE OF JUNK T...|    0|[a, piece, of, ju...|
|A good quality ba...|    1|[a, good, quality...|
|A must study for ...|    0|[a, must, study, ...|
|A pretty good pro...|    1|[a, pretty, good,...|
|A usable keyboard...|    1|[a, usable, keybo...|
|A week later afte...|    0|[a, week, later, ...|
|AFTER ARGUING WIT...|    0|[after, arguing, ...|
|AFter the first c...|    0|[after, the, firs...|
|       AMAZON SUCKS.|    0|    [amazon, sucks.]|
|     Absolutel junk.|    0|  [absolutel, junk.]|
|   Absolutely great.|    1|[absolutely, great.]|
|Adapter does not ...|    0|[adapter, does, n...|
+--------------------+-----+--------------------+
only showing top 20 rows
// Next, we'll use HashingTF to convert the tokenized words
// into fixed-length feature vector.
import org.apache.spark.ml.feature.HashingTF
val htf = new HashingTF().setNumFeatures(1000)
          .setInputCol("words")
          .setOutputCol("features")
// Check the vectorized features.
val hashedDF = htf.transform(tokenizedDF)
hashedDF.show
+--------------------+-----+--------------------+--------------------+
|                text|label|               words|            features|
+--------------------+-----+--------------------+--------------------+
|         (It works!)|    1|      [(it, works!)]|(1000,[369,504],[...|
|)Setup couldn't h...|    1|[)setup, couldn't...|(1000,[299,520,53...|
|* Comes with a st...|    1|[*, comes, with, ...|(1000,[34,51,67,1...|
|.... Item arrived...|    1|[...., item, arri...|(1000,[98,133,245...|
|1. long lasting b...|    0|[1., long, lastin...|(1000,[138,258,29...|
|2 thumbs up to th...|    1|[2, thumbs, up, t...|(1000,[92,128,373...|
|:-)Oh, the charge...|    1|[:-)oh,, the, cha...|(1000,[388,497,52...|
|   A Disappointment.|    0|[a, disappointment.]|(1000,[170,386],[...|
|A PIECE OF JUNK T...|    0|[a, piece, of, ju...|(1000,[34,36,47,7...|
|A good quality ba...|    1|[a, good, quality...|(1000,[77,82,168,...|
|A must study for ...|    0|[a, must, study, ...|(1000,[23,36,104,...|
|A pretty good pro...|    1|[a, pretty, good,...|(1000,[168,170,27...|
|A usable keyboard...|    1|[a, usable, keybo...|(1000,[2,116,170,...|
|A week later afte...|    0|[a, week, later, ...|(1000,[77,122,156...|
|AFTER ARGUING WIT...|    0|[after, arguing, ...|(1000,[77,166,202...|
|AFter the first c...|    0|[after, the, firs...|(1000,[63,77,183,...|
|       AMAZON SUCKS.|    0|    [amazon, sucks.]|(1000,[828,966],[...|
|     Absolutel junk.|    0|  [absolutel, junk.]|(1000,[607,888],[...|
|   Absolutely great.|    1|[absolutely, great.]|(1000,[589,903],[...|
|Adapter does not ...|    0|[adapter, does, n...|(1000,[0,18,51,28...|
+--------------------+-----+--------------------+--------------------+
only showing top 20 rows
// We will use the naïve Bayes classifier provided by MLlib.
import org.apache.spark.ml.classification.NaiveBayes
val nb = new NaiveBayes()
// We now have all the parts that we need to assemble
// a machine learning pipeline.
import org.apache.spark.ml.Pipeline
val pipeline = new Pipeline().setStages(Array(tokenizer, htf, nb))
// Train our model using the training dataset.
val model = pipeline.fit(trainingData)
// Predict using the test dataset.
val predictions = model.transform(testData)
// Display the predictions for each review.
predictions.select("text","prediction").show
+--------------------+----------+
|                text|prediction|
+--------------------+----------+
|!I definitely reco...|       1.0|
|#1 It Works - #2 ...|       1.0|
| $50 Down the drain.|       0.0|
|A lot of websites...|       1.0|
|After charging ov...|       0.0|
|After my phone go...|       0.0|
|All in all I thin...|       1.0|
|All it took was o...|       0.0|
|Also, if your pho...|       0.0|
|And I just love t...|       1.0|
|And none of the t...|       1.0|
|         Bad Choice.|       0.0|
|Best headset ever...|       1.0|
|Big Disappointmen...|       0.0|
|Bluetooth range i...|       0.0|
|But despite these...|       0.0|
|Buyer--Be Very Ca...|       1.0|
|Can't store anyth...|       0.0|
|Chinese Forgeries...|       0.0|
|Do NOT buy if you...|       0.0|
+--------------------+----------+
only showing top 20 rows
// Evaluate our model using a binary classifier evaluator.
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
val evaluator = new BinaryClassificationEvaluator()
import org.apache.spark.ml.param.ParamMap
val paramMap = ParamMap(evaluator.metricName -> "areaUnderROC")
val auc = evaluator.evaluate(predictions, paramMap)
auc: Double = 0.5407085561497325
// Test on a positive example.
val predictions = model
.transform(sc.parallelize(Seq("This product is good")).toDF("text"))
predictions.select("text","prediction").show
+--------------------+----------+
|                text|prediction|
+--------------------+----------+
|This product is good|       1.0|
+--------------------+----------+
// Test on a negative example.
val predictions = model
.transform(sc.parallelize(Seq("This product is bad")).toDF("text"))
predictions.select("text","prediction").show
+-------------------+----------+
|               text|prediction|
+-------------------+----------+
|This product is bad|       0.0|
+-------------------+----------+

Listing 3-8 Sentiment Analysis Using Naive Bayes

There are several things that could improve our model. Performing additional text preprocessing such as n-grams, lemmatization, and stop-word removal is common in most natural language processing (NLP) tasks. I cover Stanford CoreNLP and Spark NLP in Chapter 4.
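As a sketch of one such refinement (not part of the original listing), StopWordsRemover and NGram stages could be inserted between the tokenizer and HashingTF:

import org.apache.spark.ml.feature.{StopWordsRemover, NGram}
// Remove common stop words, then build bigrams before hashing.
val remover = new StopWordsRemover().setInputCol("words").setOutputCol("filtered")
val bigrams = new NGram().setN(2).setInputCol("filtered").setOutputCol("bigrams")
val htf2 = new HashingTF().setNumFeatures(1000).setInputCol("bigrams").setOutputCol("features")
val pipeline2 = new Pipeline().setStages(Array(tokenizer, remover, bigrams, htf2, nb))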

Regression

Regression is a supervised machine learning task for predicting continuous numeric values. Popular use cases include sales and demand forecasting, inventory prediction, house or commodity prices, and weather forecasting, to name a few. I discussed regression in more detail in Chapter 1.

Simple Linear Regression

Linear regression is used to examine the linear relationship between one or more independent variables and a dependent variable. The analysis of the relationship between a single independent variable and a single continuous dependent variable is known as simple linear regression.

This is illustrated in Figure 3-9, which shows a linear regression fitting a straight line that best minimizes the residual sum of squares between the observed responses and the predicted values.

Figure 3-9 A simple linear regression plot
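
In other words, with a single feature the fitted model is the line ŷ = β0 + β1·x, and ordinary least squares chooses the coefficients that minimize the residual sum of squares:

$$\min_{\beta_0,\beta_1}\ \sum_{i=1}^{n}\bigl(y_i - (\beta_0 + \beta_1 x_i)\bigr)^2$$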

Example

For our example, we will use simple linear regression to show how house prices (the dependent variable) change based on the area's average household income (the independent variable). The code is detailed in Listing 3-9.

import org.apache.spark.ml.regression.LinearRegression
import spark.implicits._
val dataDF = Seq(
 (50000, 302200),
 (75200, 550000),
 (90000, 680000),
 (32800, 225000),
 (41000, 275000),
 (54000, 300500),
 (72000, 525000),
 (105000, 700000),
 (88500, 673100),
 (92000, 695000),
 (53000, 320900),
 (85200, 652800),
 (157000, 890000),
 (128000, 735000),
 (71500, 523000),
 (114000, 720300),
 (33400, 265900),
 (143000, 846000),
 (68700, 492000),
 (46100, 285000)
).toDF("avg_area_income","price")
dataDF.show
+---------------+------+
|avg_area_income| price|
+---------------+------+
|          50000|302200|
|          75200|550000|
|          90000|680000|
|          32800|225000|
|          41000|275000|
|          54000|300500|
|          72000|525000|
|         105000|700000|
|          88500|673100|
|          92000|695000|
|          53000|320900|
|          85200|652800|
|         157000|890000|
|         128000|735000|
|          71500|523000|
|         114000|720300|
|          33400|265900|
|         143000|846000|
|          68700|492000|
|          46100|285000|
+---------------+------+
import org.apache.spark.ml.feature.VectorAssembler
val assembler = new VectorAssembler()
                .setInputCols(Array("avg_area_income"))
                .setOutputCol("feature")
val dataDF2 = assembler.transform(dataDF)
dataDF2.show
+---------------+------+----------+
|avg_area_income| price|   feature|
+---------------+------+----------+
|          50000|302200| [50000.0]|
|          75200|550000| [75200.0]|
|          90000|680000| [90000.0]|
|          32800|225000| [32800.0]|
|          41000|275000| [41000.0]|
|          54000|300500| [54000.0]|
|          72000|525000| [72000.0]|
|         105000|700000|[105000.0]|
|          88500|673100| [88500.0]|
|          92000|695000| [92000.0]|
|          53000|320900| [53000.0]|
|          85200|652800| [85200.0]|
|         157000|890000|[157000.0]|
|         128000|735000|[128000.0]|
|          71500|523000| [71500.0]|
|         114000|720300|[114000.0]|
|          33400|265900| [33400.0]|
|         143000|846000|[143000.0]|
|          68700|492000| [68700.0]|
|          46100|285000| [46100.0]|
+---------------+------+----------+
val lr = new LinearRegression()
         .setMaxIter(10)
         .setFeaturesCol("feature")
         .setLabelCol("price")
val model = lr.fit(dataDF2)
import org.apache.spark.ml.linalg.Vectors
val testData = spark
         .createDataFrame(Seq(Vectors.dense(75000))
         .map(Tuple1.apply))
         .toDF("feature")
val predictions = model.transform(testData)
predictions.show
+---------+------------------+
|  feature|        prediction|
+---------+------------------+
|[75000.0]|504090.35842779215|
+---------+------------------+

Listing 3-9 Linear Regression Example
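
The fitted line and basic fit statistics can also be inspected directly on the model (a short sketch using the standard LinearRegressionModel API):

// Slope, intercept, and training-summary metrics of the fitted model.
println(s"Intercept: ${model.intercept}  Slope: ${model.coefficients(0)}")
println(s"RMSE: ${model.summary.rootMeanSquaredError}  R2: ${model.summary.r2}")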

Multiple Regression with XGBoost4J-Spark

Multiple regression is used in more realistic scenarios where there are two or more independent variables and a single continuous dependent variable. In real-world use cases it is common to have both linear and nonlinear features. Tree-based ensemble algorithms such as XGBoost can handle both, which makes them a good fit for most production environments. In most cases, using a tree-based ensemble such as XGBoost for multiple regression should significantly improve prediction accuracy.xxxvi

Earlier in this chapter we used XGBoost to solve a classification problem. Since XGBoost supports both classification and regression, using XGBoost for regression is very similar to using it for classification.

Example

For the multiple regression example we will use a slightly more complex dataset, shown in Listing 3-10. The dataset can be downloaded from Kaggle.xxxvii Our goal is to predict house prices based on the attributes provided in the dataset. The dataset contains seven columns: average area income, average area house age, average area number of rooms, average area number of bedrooms, area population, price, and address. For simplicity we will not use the address field (although useful information, such as the location of nearby schools, could be derived from a home address). Price is our dependent variable.

spark-shell --packages ml.dmlc:xgboost4j-spark:0.81
import org.apache.spark.sql.types._
// Define a schema for our dataset.
var pricesSchema = StructType(Array (
    StructField("avg_area_income",   DoubleType, true),
    StructField("avg_area_house_age",   DoubleType, true),
    StructField("avg_area_num_rooms",   DoubleType, true),
    StructField("avg_area_num_bedrooms",   DoubleType, true),
    StructField("area_population",   DoubleType, true),
    StructField("price",   DoubleType, true)
    ))
val dataDF = spark.read.format("csv")
             .option("header","true")
             .schema(pricesSchema)
             .load("USA_Housing.csv").na.drop()
// Inspect the dataset.
dataDF.printSchema
root
 |-- avg_area_income: double (nullable = true)
 |-- avg_area_house_age: double (nullable = true)
 |-- avg_area_num_rooms: double (nullable = true)
 |-- avg_area_num_bedrooms: double (nullable = true)
 |-- area_population: double (nullable = true)
 |-- price: double (nullable = true)
dataDF.select("avg_area_income","avg_area_house_age","avg_area_num_rooms").show
+------------------+------------------+------------------+
|   avg_area_income|avg_area_house_age|avg_area_num_rooms|
+------------------+------------------+------------------+
| 79545.45857431678| 5.682861321615587| 7.009188142792237|
| 79248.64245482568|6.0028998082752425| 6.730821019094919|
|61287.067178656784| 5.865889840310001| 8.512727430375099|
| 63345.24004622798|7.1882360945186425| 5.586728664827653|
|59982.197225708034| 5.040554523106283| 7.839387785120487|
|  80175.7541594853|4.9884077575337145| 6.104512439428879|
| 64698.46342788773| 6.025335906887153| 8.147759585023431|
| 78394.33927753085|6.9897797477182815| 6.620477995185026|
| 59927.66081334963|  5.36212556960358|6.3931209805509015|
| 81885.92718409566| 4.423671789897876| 8.167688003472351|
| 80527.47208292288|  8.09351268063935| 5.042746799645982|
| 50593.69549704281| 4.496512793097035| 7.467627404008019|
|39033.809236982364| 7.671755372854428| 7.250029317273495|
|  73163.6634410467| 6.919534825456555|5.9931879009455695|
|  69391.3801843616| 5.344776176735725| 8.406417714534253|
| 73091.86674582321| 5.443156466535474| 8.517512711137975|
| 79706.96305765743| 5.067889591058972| 8.219771123286257|
| 61929.07701808926| 4.788550241805888|5.0970095543775615|
| 63508.19429942997| 5.947165139552473| 7.187773835329727|
| 62085.27640340488| 5.739410843630574|  7.09180810424997|
+------------------+------------------+------------------+
only showing top 20 rows
dataDF.select("avg_area_num_bedrooms","area_population","price").show
+---------------------+------------------+------------------+
|avg_area_num_bedrooms|   area_population|             price|
+---------------------+------------------+------------------+
|                 4.09|23086.800502686456|1059033.5578701235|
|                 3.09| 40173.07217364482|  1505890.91484695|
|                 5.13| 36882.15939970458|1058987.9878760849|
|                 3.26| 34310.24283090706|1260616.8066294468|
|                 4.23|26354.109472103148| 630943.4893385402|
|                 4.04|26748.428424689715|1068138.0743935304|
|                 3.41| 60828.24908540716|1502055.8173744078|
|                 2.42|36516.358972493836|1573936.5644777215|
|                  2.3| 29387.39600281585| 798869.5328331633|
|                  6.1| 40149.96574921337|1545154.8126419624|
|                  4.1| 47224.35984022191| 1707045.722158058|
|                 4.49|34343.991885578806| 663732.3968963273|
|                  3.1| 39220.36146737246|1042814.0978200927|
|                 2.27|32326.123139488096|1291331.5184858206|
|                 4.37|35521.294033173246|1402818.2101658515|
|                 4.01|23929.524053267953|1306674.6599511993|
|                 3.12| 39717.81357630952|1556786.6001947748|
|                  4.3| 24595.90149782299| 528485.2467305964|
|                 5.12|35719.653052030866|1019425.9367578316|
|                 5.49|44922.106702293066|1030591.4292116085|
+---------------------+------------------+------------------+
only showing top 20 rows
val features = Array("avg_area_income","avg_area_house_age",
"avg_area_num_rooms","avg_area_num_bedrooms","area_population")
// Combine our features into a single feature vector.
import org.apache.spark.ml.feature.VectorAssembler
val assembler = new VectorAssembler()
                .setInputCols(features)
                .setOutputCol("features")
val dataDF2 = assembler.transform(dataDF)
dataDF2.select("price","features").show(20,50)
+------------------+--------------------------------------------------+
|             price|                                          features|
+------------------+--------------------------------------------------+
|1059033.5578701235|[79545.45857431678,5.682861321615587,7.00918814...|
|  1505890.91484695|[79248.64245482568,6.0028998082752425,6.7308210...|
|1058987.9878760849|[61287.067178656784,5.865889840310001,8.5127274...|
|1260616.8066294468|[63345.24004622798,7.1882360945186425,5.5867286...|
| 630943.4893385402|[59982.197225708034,5.040554523106283,7.8393877...|
|1068138.0743935304|[80175.7541594853,4.9884077575337145,6.10451243...|
|1502055.8173744078|[64698.46342788773,6.025335906887153,8.14775958...|
|1573936.5644777215|[78394.33927753085,6.9897797477182815,6.6204779...|
| 798869.5328331633|[59927.66081334963,5.36212556960358,6.393120980...|
|1545154.8126419624|[81885.92718409566,4.423671789897876,8.16768800...|
| 1707045.722158058|[80527.47208292288,8.09351268063935,5.042746799...|
| 663732.3968963273|[50593.69549704281,4.496512793097035,7.46762740...|
|1042814.0978200927|[39033.809236982364,7.671755372854428,7.2500293...|
|1291331.5184858206|[73163.6634410467,6.919534825456555,5.993187900...|
|1402818.2101658515|[69391.3801843616,5.344776176735725,8.406417714...|
|1306674.6599511993|[73091.86674582321,5.443156466535474,8.51751271...|
|1556786.6001947748|[79706.96305765743,5.067889591058972,8.21977112...|
| 528485.2467305964|[61929.07701808926,4.788550241805888,5.09700955...|
|1019425.9367578316|[63508.19429942997,5.947165139552473,7.18777383...|
|1030591.4292116085|[62085.27640340488,5.739410843630574,7.09180810...|
+------------------+--------------------------------------------------+
only showing top 20 rows
// Divide our dataset into training and test data.
val seed = 1234
val Array(trainingData, testData) = dataDF2.randomSplit(Array(0.8, 0.2), seed)
// Use XGBoost for regression.
import ml.dmlc.xgboost4j.scala.spark.{XGBoostRegressionModel,XGBoostRegressor}
val xgb = new XGBoostRegressor()
          .setFeaturesCol("features")
          .setLabelCol("price")
// Create a parameter grid.
import org.apache.spark.ml.tuning.ParamGridBuilder
val paramGrid = new ParamGridBuilder()
                .addGrid(xgb.maxDepth, Array(6, 9))
                .addGrid(xgb.eta, Array(0.3, 0.7)).build()
paramGrid: Array[org.apache.spark.ml.param.ParamMap] =
Array({
      xgbr_bacf108db722-eta: 0.3,
      xgbr_bacf108db722-maxDepth: 6
}, {
      xgbr_bacf108db722-eta: 0.3,
      xgbr_bacf108db722-maxDepth: 9
}, {
      xgbr_bacf108db722-eta: 0.7,
      xgbr_bacf108db722-maxDepth: 6
}, {
      xgbr_bacf108db722-eta: 0.7,
      xgbr_bacf108db722-maxDepth: 9
})
// Create our evaluator.
import org.apache.spark.ml.evaluation.RegressionEvaluator
val evaluator = new RegressionEvaluator()
               .setLabelCol("price")
               .setPredictionCol("prediction")
               .setMetricName("rmse")
// Create our cross-validator.
import org.apache.spark.ml.tuning.CrossValidator
val cv = new CrossValidator()
         .setEstimator(xgb)
         .setEvaluator(evaluator)
         .setEstimatorParamMaps(paramGrid)
         .setNumFolds(3)
val model = cv.fit(trainingData)
val predictions = model.transform(testData)
predictions.select("features","price","prediction").show
+--------------------+------------------+------------+
|            features|             price|  prediction|
+--------------------+------------------+------------+
|[17796.6311895433...|302355.83597895555| 591896.9375|
|[35454.7146594754...| 1077805.577726322|   440094.75|
|[35608.9862370775...| 449331.5835333807|   672114.75|
|[38868.2503114142...| 759044.6879907805|   672114.75|
|[40752.7142433209...| 560598.5384309639| 591896.9375|
|[41007.4586732745...| 494742.5435776913|421605.28125|
|[41533.0129597444...| 682200.3005599922|505685.96875|
|[42258.7745410484...| 852703.2636757497| 591896.9375|
|[42940.1389392421...| 680418.7240122693| 591896.9375|
|[43192.1144092488...|1054606.9845532854|505685.96875|
|[43241.9824225005...| 629657.6132544072|505685.96875|
|[44328.2562966742...| 601007.3511604669|141361.53125|
|[45347.1506816944...| 541953.9056802422|441908.40625|
|[45546.6434075757...|   923830.33486809| 591896.9375|
|[45610.9384142094...|  961354.287727855|   849175.75|
|[45685.2499205068...| 867714.3838490517|441908.40625|
|[45990.1237417814...|1043968.3994445396|   849175.75|
|[46062.7542664558...| 675919.6815570832|505685.96875|
|[46367.2058588838...|268050.81474351394|  379889.625|
|[47467.4239151893...| 762144.9261238109| 591896.9375|
+--------------------+------------------+------------+
only showing top 20 rows
Listing 3-10 Multiple Regression Using XGBoost4J-Spark
Let’s evaluate the model using the root mean squared error (RMSE). Residuals are a measure of the distance of the data points from the regression line. The RMSE is the standard deviation of the residuals and is used to measure the prediction error.xxxviii
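Concretely, for observed values y_i and predictions ŷ_i:

$$\mathrm{RMSE} = \sqrt{\frac{1}{n}\sum_{i=1}^{n}\bigl(y_i - \hat{y}_i\bigr)^2}$$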
val rmse = evaluator.evaluate(predictions)
rmse: Double = 438499.82356536255
// Extract the parameters.
model.bestModel.extractParamMap
res11: org.apache.spark.ml.param.ParamMap =
{
      xgbr_8da6032c61a9-alpha: 0.0,
      xgbr_8da6032c61a9-baseScore: 0.5,
      xgbr_8da6032c61a9-checkpointInterval: -1,
      xgbr_8da6032c61a9-checkpointPath: ,
      xgbr_8da6032c61a9-colsampleBylevel: 1.0,
      xgbr_8da6032c61a9-colsampleBytree: 1.0,
      xgbr_8da6032c61a9-customEval: null,
      xgbr_8da6032c61a9-customObj: null,
      xgbr_8da6032c61a9-eta: 0.7,
      xgbr_8da6032c61a9-evalMetric: rmse,
      xgbr_8da6032c61a9-featuresCol: features,
      xgbr_8da6032c61a9-gamma: 0.0,
      xgbr_8da6032c61a9-growPolicy: depthwise,
      xgbr_8da6032c61a9-labelCol: price,
      xgbr_8da6032c61a9-lambda: 1.0,
      xgbr_8da6032c61a9-lambdaBias: 0.0,
      xgbr_8da6032c61a9-maxBin: 16,
      xgbr_8da6032c61a9-maxDeltaStep: 0.0,
      xgbr_8da6032c61a9-maxDepth: 9,
      xgbr_8da6032c61a9-minChildWeight: 1.0,
      xgbr_8da6032c61a9-missing: NaN,
      xgbr_8da6032c61a9-normalizeType: tree,
      xgbr_8da6032c61a9-nthread: 1,
      xgbr_8da6032c61a9-numEarlyStoppingRounds: 0,
      xgbr_8da6032c61a9-numRound: 1,
      xgbr_8da6032c61a9-numWorkers: 1,
      xgbr_8da6032c61a9-objective: reg:linear,
      xgbr_8da6032c61a9-predictionCol: prediction,
      xgbr_8da6032c61a9-rateDrop: 0.0,
      xgbr_8da6032c61a9-sampleType: uniform,
      xgbr_8da6032c61a9-scalePosWeight: 1.0,
      xgbr_8da6032c61a9-seed: 0,
      xgbr_8da6032c61a9-silent: 0,
      xgbr_8da6032c61a9-sketchEps: 0.03,
      xgbr_8da6032c61a9-skipDrop: 0.0,
      xgbr_8da6032c61a9-subsample: 1.0,
      xgbr_8da6032c61a9-timeoutRequestWorkers: 1800000,
      xgbr_8da6032c61a9-trackerConf: TrackerConf(0,python),
      xgbr_8da6032c61a9-trainTestRatio: 1.0,
      xgbr_8da6032c61a9-treeLimit: 0,
      xgbr_8da6032c61a9-treeMethod: auto,
      xgbr_8da6032c61a9-useExternalMemory: false
}


Multiple Regression with LightGBM

In Listing 3-11 we will use LightGBM. LightGBM comes with the LightGBMRegressor class specifically for regression tasks. We will reuse the housing dataset and most of the code from the previous XGBoost example.

spark-shell --packages Azure:mmlspark:0.15
var pricesSchema = StructType(Array (
    StructField("avg_area_income",   DoubleType, true),
    StructField("avg_area_house_age",   DoubleType, true),
    StructField("avg_area_num_rooms",   DoubleType, true),
    StructField("avg_area_num_bedrooms",   DoubleType, true),
    StructField("area_population",   DoubleType, true),
    StructField("price",   DoubleType, true)
    ))
val dataDF = spark.read.format("csv")
             .option("header","true")
             .schema(pricesSchema)
             .load("USA_Housing.csv")
             .na.drop()
dataDF.printSchema
root
 |-- avg_area_income: double (nullable = true)
 |-- avg_area_house_age: double (nullable = true)
 |-- avg_area_num_rooms: double (nullable = true)
 |-- avg_area_num_bedrooms: double (nullable = true)
 |-- area_population: double (nullable = true)
 |-- price: double (nullable = true)
dataDF.select("avg_area_income","avg_area_house_age","avg_area_num_rooms").show
+------------------+------------------+------------------+
|   avg_area_income|avg_area_house_age|avg_area_num_rooms|
+------------------+------------------+------------------+
| 79545.45857431678| 5.682861321615587| 7.009188142792237|
| 79248.64245482568|6.0028998082752425| 6.730821019094919|
|61287.067178656784| 5.865889840310001| 8.512727430375099|
| 63345.24004622798|7.1882360945186425| 5.586728664827653|
|59982.197225708034| 5.040554523106283| 7.839387785120487|
|  80175.7541594853|4.9884077575337145| 6.104512439428879|
| 64698.46342788773| 6.025335906887153| 8.147759585023431|
| 78394.33927753085|6.9897797477182815| 6.620477995185026|
| 59927.66081334963|  5.36212556960358|6.3931209805509015|
| 81885.92718409566| 4.423671789897876| 8.167688003472351|
| 80527.47208292288|  8.09351268063935| 5.042746799645982|
| 50593.69549704281| 4.496512793097035| 7.467627404008019|
|39033.809236982364| 7.671755372854428| 7.250029317273495|
|  73163.6634410467| 6.919534825456555|5.9931879009455695|
|  69391.3801843616| 5.344776176735725| 8.406417714534253|
| 73091.86674582321| 5.443156466535474| 8.517512711137975|
| 79706.96305765743| 5.067889591058972| 8.219771123286257|
| 61929.07701808926| 4.788550241805888|5.0970095543775615|
| 63508.19429942997| 5.947165139552473| 7.187773835329727|
| 62085.27640340488| 5.739410843630574|  7.09180810424997|
+------------------+------------------+------------------+
dataDF.select("avg_area_num_bedrooms","area_population","price").show
+---------------------+------------------+------------------+
|avg_area_num_bedrooms|   area_population|             price|
+---------------------+------------------+------------------+
|                 4.09|23086.800502686456|1059033.5578701235|
|                 3.09| 40173.07217364482|  1505890.91484695|
|                 5.13| 36882.15939970458|1058987.9878760849|
|                 3.26| 34310.24283090706|1260616.8066294468|
|                 4.23|26354.109472103148| 630943.4893385402|
|                 4.04|26748.428424689715|1068138.0743935304|
|                 3.41| 60828.24908540716|1502055.8173744078|
|                 2.42|36516.358972493836|1573936.5644777215|
|                  2.3| 29387.39600281585| 798869.5328331633|
|                  6.1| 40149.96574921337|1545154.8126419624|
|                  4.1| 47224.35984022191| 1707045.722158058|
|                 4.49|34343.991885578806| 663732.3968963273|
|                  3.1| 39220.36146737246|1042814.0978200927|
|                 2.27|32326.123139488096|1291331.5184858206|
|                 4.37|35521.294033173246|1402818.2101658515|
|                 4.01|23929.524053267953|1306674.6599511993|
|                 3.12| 39717.81357630952|1556786.6001947748|
|                  4.3| 24595.90149782299| 528485.2467305964|
|                 5.12|35719.653052030866|1019425.9367578316|
|                 5.49|44922.106702293066|1030591.4292116085|
+---------------------+------------------+------------------+
only showing top 20 rows
val features = Array("avg_area_income","avg_area_house_age",
"avg_area_num_rooms","avg_area_num_bedrooms","area_population")
import org.apache.spark.ml.feature.VectorAssembler
val assembler = new VectorAssembler()
                .setInputCols(features)
                .setOutputCol("features")
val dataDF2 = assembler.transform(dataDF)
dataDF2.select("price","features").show(20,50)
+------------------+--------------------------------------------------+
|             price|                                          features|
+------------------+--------------------------------------------------+
|1059033.5578701235|[79545.45857431678,5.682861321615587,7.00918814...|
|  1505890.91484695|[79248.64245482568,6.0028998082752425,6.7308210...|
|1058987.9878760849|[61287.067178656784,5.865889840310001,8.5127274...|
|1260616.8066294468|[63345.24004622798,7.1882360945186425,5.5867286...|
| 630943.4893385402|[59982.197225708034,5.040554523106283,7.8393877...|
|1068138.0743935304|[80175.7541594853,4.9884077575337145,6.10451243...|
|1502055.8173744078|[64698.46342788773,6.025335906887153,8.14775958...|
|1573936.5644777215|[78394.33927753085,6.9897797477182815,6.6204779...|
| 798869.5328331633|[59927.66081334963,5.36212556960358,6.393120980...|
|1545154.8126419624|[81885.92718409566,4.423671789897876,8.16768800...|
| 1707045.722158058|[80527.47208292288,8.09351268063935,5.042746799...|
| 663732.3968963273|[50593.69549704281,4.496512793097035,7.46762740...|
|1042814.0978200927|[39033.809236982364,7.671755372854428,7.2500293...|
|1291331.5184858206|[73163.6634410467,6.919534825456555,5.993187900...|
|1402818.2101658515|[69391.3801843616,5.344776176735725,8.406417714...|
|1306674.6599511993|[73091.86674582321,5.443156466535474,8.51751271...|
|1556786.6001947748|[79706.96305765743,5.067889591058972,8.21977112...|
| 528485.2467305964|[61929.07701808926,4.788550241805888,5.09700955...|
|1019425.9367578316|[63508.19429942997,5.947165139552473,7.18777383...|
|1030591.4292116085|[62085.27640340488,5.739410843630574,7.09180810...|
+------------------+--------------------------------------------------+
only showing top 20 rows
val seed = 1234
val Array(trainingData, testData) = dataDF2.randomSplit(Array(0.8, 0.2), seed)
import com.microsoft.ml.spark.{LightGBMRegressionModel,LightGBMRegressor}
val lightgbm = new LightGBMRegressor()
               .setFeaturesCol("features")
               .setLabelCol("price")
               .setObjective("regression")
import org.apache.spark.ml.tuning.ParamGridBuilder
val paramGrid = new ParamGridBuilder()
                .addGrid(lightgbm.numLeaves, Array(6, 9))
                .addGrid(lightgbm.numIterations, Array(10, 15))
                .addGrid(lightgbm.maxDepth, Array(2, 3, 4))
                .build()
paramGrid: Array[org.apache.spark.ml.param.ParamMap] =
Array({
        LightGBMRegressor_f969f7c475b5-maxDepth: 2,
        LightGBMRegressor_f969f7c475b5-numIterations: 10,
        LightGBMRegressor_f969f7c475b5-numLeaves: 6
}, {
        LightGBMRegressor_f969f7c475b5-maxDepth: 3,
        LightGBMRegressor_f969f7c475b5-numIterations: 10,
        LightGBMRegressor_f969f7c475b5-numLeaves: 6
}, {
        LightGBMRegressor_f969f7c475b5-maxDepth: 4,
        LightGBMRegressor_f969f7c475b5-numIterations: 10,
        LightGBMRegressor_f969f7c475b5-numLeaves: 6
}, {
        LightGBMRegressor_f969f7c475b5-maxDepth: 2,
        LightGBMRegressor_f969f7c475b5-numIterations: 10,
        LightGBMRegressor_f969f7c475b5-numLeaves: 9
}, {
        LightGBMRegressor_f969f7c475b5-maxDepth: 3,
        LightGBMRegressor_f969f7c475b5-numIterations: 10,
        LightGBMRegressor_f969f7c475b5-numLeaves: 9
}, {
        Lig...
import org.apache.spark.ml.evaluation.RegressionEvaluator
val evaluator = new RegressionEvaluator()
                .setLabelCol("price")
                .setPredictionCol("prediction")
                .setMetricName("rmse")
import org.apache.spark.ml.tuning.CrossValidator
val cv = new CrossValidator()
         .setEstimator(lightgbm)
         .setEvaluator(evaluator)
         .setEstimatorParamMaps(paramGrid)
         .setNumFolds(3)
val model = cv.fit(trainingData)
val predictions = model.transform(testData)
predictions.select("features","price","prediction").show
+--------------------+------------------+------------------+
|            features|             price|        prediction|
+--------------------+------------------+------------------+
|[17796.6311895433...|302355.83597895555| 965317.3181705693|
|[35454.7146594754...| 1077805.577726322|1093159.8506664087|
|[35608.9862370775...| 449331.5835333807|1061505.7131801855|
|[38868.2503114142...| 759044.6879907805|1061505.7131801855|
|[40752.7142433209...| 560598.5384309639| 974582.8481703462|
|[41007.4586732745...| 494742.5435776913| 881891.5646432829|
|[41533.0129597444...| 682200.3005599922| 966417.0064436384|
|[42258.7745410484...| 852703.2636757497|1070641.7611960804|
|[42940.1389392421...| 680418.7240122693|1028986.6314725328|
|[43192.1144092488...|1054606.9845532854|1087808.2361520242|
|[43241.9824225005...| 629657.6132544072| 889012.3734817103|
|[44328.2562966742...| 601007.3511604669| 828175.3829271109|
|[45347.1506816944...| 541953.9056802422| 860754.7467075661|
|[45546.6434075757...|   923830.33486809| 950407.7970842035|
|[45610.9384142094...|  961354.287727855|1175429.1179985087|
|[45685.2499205068...| 867714.3838490517|  828812.007346283|
|[45990.1237417814...|1043968.3994445396|1204501.1530193759|
|[46062.7542664558...| 675919.6815570832| 973273.6042265462|
|[46367.2058588838...|268050.81474351394| 761576.9192149616|
|[47467.4239151893...| 762144.9261238109| 951908.0117790927|
+--------------------+------------------+------------------+
only showing top 20 rows
val rmse = evaluator.evaluate(predictions)
rmse: Double = 198601.74726198777

Listing 3-11 Multiple Regression Using LightGBM

Let's extract the feature importance score for each feature.

val model = lightgbm.fit(trainingData)
model.getFeatureImportances("gain")
res7: Array[Double] = Array(1.110789482705408E15, 5.69355224816896E14, 3.25231517467648E14, 1.16104381056E13, 4.84685311277056E14)

Matching the order of the output in the list with the order of the features in the feature vector (avg_area_income, avg_area_house_age, avg_area_num_rooms, avg_area_num_bedrooms, area_population), it appears that avg_area_income is our most important feature, followed by avg_area_house_age, area_population, and avg_area_num_rooms. The least important feature is avg_area_num_bedrooms.
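
A small convenience sketch (not in the original listing) pairs each feature name with its score and sorts descending, which makes the ranking easier to read:

// Pair feature names with their gain-based importance scores and sort.
features.zip(model.getFeatureImportances("gain"))
        .sortBy(-_._2)
        .foreach { case (name, score) => println(f"$name%-25s $score%.4e") }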

Summary

I discussed some of the most popular supervised learning algorithms included in Spark MLlib, as well as newer algorithms available externally, such as XGBoost and LightGBM. While there is plenty of Python documentation for XGBoost and LightGBM online, information and examples for Spark are limited. This chapter aims to help bridge that gap.
