Running XGBoost on Spark: the Scala Interface

Overview

XGBoost can run on Spark. I am using XGBoost version 0.7, which currently only supports Spark 2.0 and above.

Build the jar and install it into your local Maven repository:

  
  
mvn install:install-file -Dfile=xgboost4j-spark-0.7-jar-with-dependencies.jar -DgroupId=ml.dmlc -DartifactId=xgboost4j-spark -Dversion=0.7 -Dpackaging=jar


Add the dependencies (note the _2.10 suffix is the Scala binary version; the Spark artifacts must match the Scala version the XGBoost jar was built with):

<dependencies>
	<dependency>
		<groupId>ml.dmlc</groupId>
		<artifactId>xgboost4j-spark</artifactId>
		<version>0.7</version>
	</dependency>
	<dependency>
		<groupId>org.apache.spark</groupId>
		<artifactId>spark-core_2.10</artifactId>
		<version>2.0.0</version>
	</dependency>
	<dependency>
		<groupId>org.apache.spark</groupId>
		<artifactId>spark-mllib_2.10</artifactId>
		<version>2.0.0</version>
	</dependency>
</dependencies>




RDD interface:


package com.meituan.spark_xgboost
import org.apache.log4j.{ Level, Logger }
import org.apache.spark.{ SparkConf, SparkContext }
import ml.dmlc.xgboost4j.scala.spark.XGBoost
import org.apache.spark.sql.{ SparkSession, Row }
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.ml.feature.LabeledPoint
import org.apache.spark.ml.linalg.Vectors
object XgboostR {


  def main(args: Array[String]): Unit = {
    Logger.getLogger("org.apache.spark").setLevel(Level.ERROR)
    Logger.getLogger("org.eclipse.jetty.server").setLevel(Level.OFF)
    val spark = SparkSession.builder.master("local").appName("example").
      config("spark.sql.warehouse.dir", s"file:///Users/shuubiasahi/Documents/spark-warehouse").
      config("spark.sql.shuffle.partitions", "20").getOrCreate()
    spark.conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
    val path = "/Users/shuubiasahi/Documents/workspace/xgboost/demo/data/"
    val trainString = "agaricus.txt.train"
    val testString = "agaricus.txt.test"
    val train = MLUtils.loadLibSVMFile(spark.sparkContext, path + trainString)
    val test = MLUtils.loadLibSVMFile(spark.sparkContext, path + testString)
    // rebuild each training example as an ml LabeledPoint with a dense feature vector
    val traindata = train.map { x =>
      val f = x.features.toArray
      val v = x.label
      LabeledPoint(v, Vectors.dense(f))
    }
    // for prediction only the feature vectors are needed
    val testdata = test.map { x =>
      val f = x.features.toArray
      Vectors.dense(f)
    }
    

    val numRound = 15

    // for regression you would instead use:
    //   "objective" -> "reg:linear", // learning task and the corresponding objective
    //   "eval_metric" -> "rmse",     // evaluation metric for the validation data

    val paramMap = List(
      "eta" -> 1f,
      "max_depth" -> 5, // maximum depth of a tree; default 6, range [1, ∞]
      "silent" -> 1, // 0 prints runtime messages, 1 runs silently; default 0
      "objective" -> "binary:logistic", // learning task and the corresponding objective
      "lambda" -> 2.5,
      "nthread" -> 1 // number of threads XGBoost uses; defaults to the maximum available
      ).toMap
    println(paramMap)
    

    val model = XGBoost.trainWithRDD(traindata, paramMap, numRound, 55, null, null, useExternalMemory = false, Float.NaN)
    println("success")
 
    val result = model.predict(testdata)
    result.take(10).foreach(println)
    spark.stop()
   
  }

}
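In trainWithRDD above, the positional arguments after the parameter map are the number of boosting rounds and the number of parallel workers (55 here), followed by slots for custom objective and evaluation functions (null to use the built-ins). To reuse the trained model in a later job, here is a minimal sketch assuming the 0.7-era saveModelAsHadoopFile / loadModelFromHadoopFile helpers; the output path is a placeholder, not from the original post:

// Persist the trained model, then load it back (for example in another job).
// Both helpers take the SparkContext as a second argument list; the path
// below is a placeholder.
model.saveModelAsHadoopFile("/tmp/xgboost-agaricus.model")(spark.sparkContext)
val reloaded = XGBoost.loadModelFromHadoopFile("/tmp/xgboost-agaricus.model")(spark.sparkContext)
val reloadedResult = reloaded.predict(testdata)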


DataFrame interface:

package com.meituan.spark_xgboost
import org.apache.log4j.{ Level, Logger }
import org.apache.spark.{ SparkConf, SparkContext }
import ml.dmlc.xgboost4j.scala.spark.XGBoost
import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
import org.apache.spark.sql.{ SparkSession, Row }
object XgboostD {
  def main(args: Array[String]): Unit = {
    Logger.getLogger("org.apache.spark").setLevel(Level.ERROR)
    Logger.getLogger("org.eclipse.jetty.server").setLevel(Level.OFF)
    val spark = SparkSession.builder.master("local").appName("example").
      config("spark.sql.warehouse.dir", s"file:///Users/shuubiasahi/Documents/spark-warehouse").
      config("spark.sql.shuffle.partitions", "20").getOrCreate()
    spark.conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
    val path = "/Users/shuubiasahi/Documents/workspace/xgboost/demo/data/"
    val trainString = "agaricus.txt.train"
    val testString = "agaricus.txt.test"

    val train = spark.read.format("libsvm").load(path + trainString).toDF("label", "feature")

    val test = spark.read.format("libsvm").load(path + testString).toDF("label", "feature")

    val numRound = 15

    //"objective" -> "reg:linear", //定义学习任务及相应的学习目标
    //"eval_metric" -> "rmse", //校验数据所需要的评价指标  用于做回归

    val paramMap = List(
      "eta" -> 1f,
      "max_depth" -> 5, //数的最大深度。缺省值为6 ,取值范围为:[1,∞] 
      "silent" -> 1, //取0时表示打印出运行时信息,取1时表示以缄默方式运行,不打印运行时信息。缺省值为0 
      "objective" -> "binary:logistic", //定义学习任务及相应的学习目标
      "lambda" -> 2.5,
      "nthread" -> 1 //XGBoost运行时的线程数。缺省值是当前系统可以获得的最大线程数
      ).toMap
    val model = XGBoost.trainWithDataFrame(train, paramMap, numRound, 45, obj = null, eval = null, useExternalMemory = false, Float.NaN, "feature", "label")
    val predict = model.transform(test)

    val scoreAndLabels = predict.select(model.getPredictionCol, model.getLabelCol)
      .rdd
      .map { case Row(score: Double, label: Double) => (score, label) }

    // compute the AUC
    val metric = new BinaryClassificationMetrics(scoreAndLabels)
    val auc = metric.areaUnderROC()
    println("auc: " + auc)
    spark.stop()

  }

}
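BinaryClassificationMetrics treats the prediction column as a probability score, which is what binary:logistic produces. If you also want hard 0/1 labels at the conventional 0.5 cutoff, a small follow-up sketch (the predictedLabel column name is made up for illustration):

// Threshold the predicted probability into a hard 0/1 label.
// "predictedLabel" is a hypothetical column name for this sketch.
import org.apache.spark.sql.functions.{ col, when }
val hard = predict.withColumn("predictedLabel",
  when(col(model.getPredictionCol) >= 0.5, 1.0).otherwise(0.0))
hard.select(model.getLabelCol, model.getPredictionCol, "predictedLabel").show(10)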


