书接上回
在上篇中,我们介绍了XGBoost的原生接口使用方法,以及sklearn版本的接口。本篇我们再结合Scala/Spark来聊聊,以体现XGBoost在工程上的易用性。
Spark是基于Scala原生语言开发的一个分布式迭代计算平台,其中MLLib模块包括了很多机器学习算法包(但比起Sklearn来肯定还是少的)。
Scala 是一门面向对象+函数式JVM语言,需要编译后才能执行,但它提供了像Python那样的交互式编程方式,调试代码非常方便。推荐有好奇心的同学去了解下。
Scala + XGBoost
官方DEMO
官方网站上的一段DEMO程序如下:
import ml.dmlc.xgboost4j.scala.DMatrix
import ml.dmlc.xgboost4j.scala.XGBoost
object XGBoostScalaExample {
def main(args: Array[String]) {
// read trainining data, available at xgboost/demo/data
val trainData =
new DMatrix("/path/to/agaricus.txt.train")
// define parameters
val paramMap = List(
"eta" -> 0.1,
"max_depth" -> 2,
"objective" -> "binary:logistic").toMap
// number of iterations
val round = 2
// train the model
val model = XGBoost.train(trainData, paramMap, round)
// run prediction
val predTrain = model.predict(trainData)
// save model to the file.
model.saveModel("/local/path/to/model")
}
下面我们来解剖下这段小程序,以做到以点带面地了解其使用细节。
环境配置
既然Scala是一门编译型语言,我们先需要搞清楚怎么在里面去使用到XBGoost。对于Java而言,我们可以手动下载软件的JAR包后导入工程,也可以通过配置Maven依赖来自动下载依赖的JAR包。而Scala天生地依赖于JVM生态,几乎所有java类JAR包都适用于它。
在[https://github.com/dmlc/xgboost/tree/master/jvm-packages] 这个页面上我们可以找到适用于JVM的Maven依赖:
<dependency>
<groupId>ml.dmlc</groupId>
<artifactId>xgboost4j</artifactId>
<version>latest_version_num</version>
</dependency>
官网上说最新的版本号可以从https://github.com/dmlc/xgboost/releases上面去找,这里我们看到稳定版本已经有1.0.0了。
但事实上,如果直接把上面依赖项中的latest_version_num改为1.0.0,maven自动是查找不到的。所以我们人肉去maven repository 网站上(https://mvnrepository.com/)找一下更靠谱,会发现这里其实更新到0.90版本。
点击进去后,获得如下的依赖项:
<!-- https://mvnrepository.com/artifact/ml.dmlc/xgboost4j -->
<dependency>
<groupId>ml.dmlc</groupId>
<artifactId