spark XGBoost算法demo

1.运行环境配置

     该算法需要运行Linux环境下,运行的版本为:spark2.4.0,scala 2.11

2.maven配置


    <properties>
        <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
        <project.reporting.outputEncoding>UTF-8</project.reporting.outputEncoding>
        <scala.version>2.11</scala.version>
        <spark.version>2.4.0</spark.version>
    </properties>
    <dependencies>
        <dependency>
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-sql_${scala.version}</artifactId>
            <version>${spark.version}</version>
        </dependency>
        <dependency>
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-core_${scala.version}</artifactId>
            <version>${spark.version}</version>
        </dependency>
        <dependency>
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-mllib_${scala.version}</artifactId>
            <version>${spark.version}</version>
        </dependency>
        <dependency>
            <groupId>ml.dmlc</groupId>
            <artifactId>xgboost4j</artifactId>
            <version>0.72</version>
        </dependency>
        <dependency>
            <groupId>ml.dmlc</groupId>
            <artifactId>xgboost4j-spark</artifactId>
            <version>0.72</version>
        </dependency>
    </dependencies>
    <build>
        <plugins>
            <plugin>
                <groupId>org.scala-tools</groupId>
                <artifactId>maven-scala-plugin</artifactId>
                <version>2.15.2</version>
                <executions>
                    <execution>
                        <goals>
                            <goal>compile</goal>
                            <goal>testCompile</goal>
                        </goals>
                    </execution>
                </executions>
            </plugin>

            <plugin>
                <artifactId>maven-compiler-plugin</artifactId>
                <version>3.6.0</version>
                <configuration>
                    <source>1.8</source>
                    <target>1.8</target>
                </configuration>
            </plugin>

            <plugin>
                <groupId>org.apache.maven.plugins</groupId>
                <artifactId>maven-surefire-plugin</artifactId>
                <version>2.19</version>
                <configuration>
                    <skip>true</skip>
                </configuration>
            </plugin>
            <!-- 打出jar包引用关联包 -->
            <plugin>
                <groupId>org.apache.maven.plugins</groupId>
                <artifactId>maven-jar-plugin</artifactId>
                <version>2.4</version>
                <configuration>
                    <archive>
                        <manifest>
                            <addClasspath>true</addClasspath>
                            <classpathPrefix>lib/</classpathPrefix>
                            <!--<mainClass>com.caxs.artemis.model.schedule.ModelInvoke</mainClass>-->
                        </manifest>
                    </archive>
                </configuration>
            </plugin>
            <!-- 将依赖包放到lib文件夹中 -->
            <plugin>
                <groupId>org.apache.maven.plugins</groupId>
                <artifactId>maven-dependency-plugin</artifactId>
                <executions>
                    <execution>
                        <id>copy</id>
                        <phase>package</phase>
                        <goals>
                            <goal>copy-dependencies</goal>
                        </goals>
                        <configuration>
                            <outputDirectory>
                                ${project.build.directory}/lib
                            </outputDirectory>
                        </configuration>
                    </execution>
                </executions>
            </plugin>
        </plugins>
    </build>

3.运行demo

package spark

import ml.dmlc.xgboost4j.scala.spark.XGBoost
import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.sql.{DataFrame, SparkSession}

/**
  * author     :test-abc
  * date       :Created in 2019/9/3 11:04
  * description:xgboost demo
  * modified By:
  */

object XgboostDemo {
  def main(args: Array[String]): Unit = {
    val spark: SparkSession = SparkSession.builder()
      .appName("SparkSql")
//      .master("local[2]")
      .getOrCreate()
    //准备示例数据,将数据转为dataframe
    import spark.implicits._
    val dataList: List[(Int, Double, Double, Double, Double, Double, Double)] = List(
      (0, 8.9255, -6.7863, 11.9081, 5.093, 11.4607, -9.2834),
      (1, 11.5006, -4.1473, 13.8588, 5.389, 12.3622, 7.0433),
      (0, 8.6093, -2.7457, 12.0805, 7.8928, 10.5825, -9.0837),
      (1, 11.0604, -2.1518, 8.9522, 7.1957, 12.5846, -1.8361),
      (1, 9.8369, -1.4834, 12.8746, 6.6375, 12.2772, 2.4486),
      (1, 11.4763, -2.3182, 12.608, 8.6264, 10.9621, 3.5609),
      (0, 11.8091, -0.0832, 9.3494, 4.2916, 11.1355, -8.0198),
      (0, 13.558, -7.9881, 13.8776, 7.5985, 8.6543, 0.831),
      (0, 16.1071, 2.4426, 13.9307, 5.6327, 8.8014, 6.163),
      (1, 12.5088, 1.9743, 8.896, 5.4508, 13.6043, -16.2859),
      (0, 5.0702, -0.5447, 9.59, 4.2987, 12.391, -18.8687),
      (0, 12.7188, -7.975, 10.3757, 9.0101, 12.857, -12.0852),
      (0, 8.7671, -4.6154, 9.7242, 7.4242, 9.0254, 1.4247),
      (1, 16.3699, 1.5934, 16.7395, 7.333, 12.145, 5.9004),
      (0, 13.808, 5.0514, 17.2611, 8.512, 12.8517, -9.1622),
      (0, 3.9416, 2.6562, 13.3633, 6.8895, 12.2806, -16.162),
      (0, 5.0615, 0.2689, 15.1325, 3.6587, 13.5276, -6.5477),
      (1, 8.4199, -1.8128, 8.1202, 5.3955, 9.7184, -17.839),
      (0, 4.875, 1.2646, 11.919, 8.465, 10.7203, -0.6707),
      (1, 4.409, -0.7863, 15.1828, 8.0631, 11.2831, -0.7356))

    val inputDF: DataFrame = dataList.toDF("label", "feature1", "feature2", "feature3", "feature4", "feature5", "feature6")
    //将需要转换的列合并为向量列
    val transCols: Array[String] = Array("feature1", "feature2", "feature3", "feature4", "feature5", "feature6")
    val assembler: VectorAssembler = new VectorAssembler().setInputCols(transCols).setOutputCol("features")
    val xGBoostTrainInput: DataFrame = assembler.transform(inputDF).select("features","label")
    xGBoostTrainInput.show(10)

//    val paramMap = List(
//      "eta" -> 0.01, //学习率
//      "gamma" -> 0.1, //用于控制是否后剪枝的参数,越大越保守,一般0.1、0.2这样子。
//      "lambda" -> 2, //控制模型复杂度的权重值的L2正则化项参数,参数越大,模型越不容易过拟合。
//      "subsample" -> 0.8, //随机采样训练样本
//      "colsample_bytree" -> 0.8, //生成树时进行的列采样
//      "max_depth" -> 5, //构建树的深度,越大越容易过拟合
//      "min_child_weight" -> 5,
//      "objective" -> "multi:softprob",  //定义学习任务及相应的学习目标
//      "eval_metric" -> "merror",
//      "num_class" -> 21
//    ).toMap

    val paramMap = List(
      "colsample_bytree" -> 1,
      "eta" -> 0.05f, //就是学习率
      "max_depth" -> 8, //树的最大深度
      "min_child_weight" -> 5, //
      "n_estimators" -> 120,
      "subsample" -> 0.7
    ).toMap


    //模型训练
    val xgBoostModel = XGBoost.trainWithDataFrame(xGBoostTrainInput, paramMap, round = 10, nWorkers = 3,
      useExternalMemory = true, featureCol = "features", labelCol = "label")

    //准备预测数据
    val testList: List[( Double, Double, Double, Double, Double, Double)] = List(
      ( 8.9225, -6.7863, 11.9081, 5.093, 11.4607, -9.2834),
      ( 11.5006, -4.1473, 13.8588, 5.389, 12.3622, 7.0433),
      ( 8.6093, -2.7457, 12.0805, 7.8928, 10.5825, -9.0837),
      ( 11.0604, -2.1518, 8.9522, 7.1957, 12.5846, -1.8361),
      ( 9.8369, -11.4834, 12.8746, 6.6375, 12.2772, 2.4486),
      ( 11.4763, -2.3182, 12.608, 8.6264, 10.9621, 3.5609),
      ( 11.8091, -10.0832, 9.3494, 4.2916, 11.1355, -8.0198),
      ( 13.558, -7.9881, 13.8776, 7.5985, 8.6543, 0.831),
      ( 16.1071, 1.4426, 13.9307, 5.6327, 8.8014, 6.163),
      ( 12.5088, 2.9743, 8.896, 5.4508, 13.6043, -16.2859),
      ( 5.0702, -0.5447, 9.59, 4.2987, 12.391, -18.8687),
      ( 12.7188, -7.975, 10.3757, 9.0101, 12.857, -12.0852),
      ( 8.7671, -4.6154, 8.7242, 7.4242, 9.0254, 1.4247),
      ( 16.3699, 1.5934, 16.7395, 7.333, 12.145, 5.9004),
      ( 13.808, 5.0514, 17.2611, 8.512, 12.8517, -9.1622),
      ( 3.9416, 2.6562, 13.3633, 6.8895, 12.2806, -16.162),
      ( 5.0615, 0.2689, 15.1325, 3.6587, 13.5276, -6.5477),
      ( 8.4199, -1.8128, 9.1202, 5.3955, 9.7184, -17.839),
      ( 5.875, 1.2646, 11.919, 8.465, 10.7203, -0.6707),
      ( 5.409, -0.7863, 15.1828, 8.0631, 11.2831, -0.7356))

    val testDf: DataFrame = testList.toDF("feature1", "feature2", "feature3", "feature4", "feature5", "feature6")
    //将测试数据集转为向量
    val xGBoostTestInput: DataFrame = assembler.transform(testDf).select("features")
    xGBoostTestInput.show(10)
    //模型预测
    val output: DataFrame = xgBoostModel.transform(xGBoostTestInput)
    output.show()
    spark.close()
  }
}

运行结果为:

评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值