Spark 实现 XGBoost 鸢尾花分类预测(多分类)

11

package com.fwmagic.spark.xgboost

import ml.dmlc.xgboost4j.scala.spark.{XGBoostClassificationModel, XGBoostClassifier}
import org.apache.spark.ml.{Pipeline, PipelineModel}
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.ml.feature._
import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder}
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types._

/**
 * XGBoost4J-Spark multi-class classification example on the iris data set.
 *
 * The program reads a headerless CSV with four numeric feature columns and a
 * string `class` column, trains an ML pipeline
 * (VectorAssembler -> StringIndexerModel -> XGBoostClassifier -> IndexToString),
 * evaluates it on a held-out split, tunes it with 3-fold cross validation,
 * exports the best native booster, persists the pipeline model and finally
 * reloads the persisted pipeline to score the test split again.
 *
 * Command line: either no arguments (built-in local paths are used) or exactly
 * three arguments: `input_path native_model_path pipeline_model_path`.
 */
object SparkMLlibPipeline {

    /**
     * Number of XGBoost workers. The same value is used for the number of
     * local Spark threads: XGBoost needs all of its workers to run at the
     * same time, so requesting more workers than available task slots makes
     * training hang.
     */
    private val NumWorkers = 5

    private val Usage =
        "Usage: SparkMLlibPipeline input_path native_model_path pipeline_model_path"

    /**
     * Returns the column names of `df` with the last column dropped.
     *
     * Kept as an alternative to the hard-coded feature list used by the
     * assembler below; it is only correct when the label is the last column.
     *
     * @param df data frame whose last column is not a feature
     * @return all column names of `df` except the last one
     */
    private def getColumnArray(df: org.apache.spark.sql.DataFrame): Array[String] =
        df.columns.dropRight(1)

    /**
     * Program entry point.
     *
     * @param args empty (use built-in local paths) or
     *             `Array(inputPath, nativeModelPath, pipelineModelPath)`;
     *             any other arity prints the usage line and exits with status 1
     */
    def main(args: Array[String]): Unit = {

        // Resolve the three paths in one expression instead of mutating vars.
        val (inputPath, nativeModelPath, pipelineModelPath) = args match {
            case Array() =>
                println("Using local path")
                (
                    "/Users/zz/Desktop/com.zz.spark/spark-ml/skhmagic-spark-mllib/data/xgboost/iris.data",
                    "/Users/zz/Desktop/com.zz.spark/spark-ml/skhmagic-spark-mllib/model_save/xgboost/native_model",
                    "/Users/zz/Desktop/com.zz.spark/spark-ml/skhmagic-spark-mllib/model_save/xgboost/pipeline_model"
                )
            case Array(input, nativeModel, pipelineModel) =>
                (input, nativeModel, pipelineModel)
            case _ =>
                println(Usage)
                sys.exit(1)
        }

        // The local thread count must be >= num_workers (see NumWorkers).
        val spark = SparkSession.builder()
            .master(s"local[$NumWorkers]")
            .appName("XGBoost4J-Spark Pipeline Example")
            .getOrCreate()

        try {
            // Load dataset
            val schema = new StructType(Array(
                StructField("sepal length", DoubleType, true),
                StructField("sepal width", DoubleType, true),
                StructField("petal length", DoubleType, true),
                StructField("petal width", DoubleType, true),
                StructField("class", StringType, true)))

            val rawInput = spark.read.schema(schema).csv(inputPath)
            rawInput.show(5, false)

            // Split training and test dataset (fixed seed for a repeatable split)
            val Array(training, test) = rawInput.randomSplit(Array(0.8, 0.2), 123)
            // val Array(training, eval1, eval2, test) = rawInput.randomSplit(Array(0.6, 0.1, 0.1, 0.2))

            // Build ML pipeline, it includes 4 stages:

            // 1, Assemble all features into a single vector column.
            val assembler = new VectorAssembler()
                .setInputCols(Array("sepal length", "sepal width", "petal length", "petal width"))
                //.setInputCols(getColumnArray(rawInput))
                .setOutputCol("features")

            // 2, From string label to indexed double label.
            //    Fitted up front because stage 4 needs the learned label array.
            val labelIndexer = new StringIndexer()
                .setInputCol("class")
                .setOutputCol("classIndex")
                .fit(training)

            // 3, Use XGBoostClassifier to train classification model.
            // NOTE: num_workers must be <= the N in local[N], otherwise the job hangs.
            val xgbParam = Map(
                "eta" -> 0.1f,
                "max_depth" -> 2,
                "objective" -> "multi:softprob",
                // "objective" -> "binary:logistic",
                "num_class" -> 3,
                "num_round" -> 100,
                "num_workers" -> NumWorkers
                // "tree_method" -> treeMethod,
                // "eval_sets" -> Map("eval1" -> eval1, "eval2" -> eval2)
            )

            // Create the XGBoost estimator and point it at the feature vector and label.
            val booster = new XGBoostClassifier(xgbParam)
                .setFeaturesCol("features")
                .setLabelCol("classIndex")

            // 4, Convert indexed double label back to original string label.
            val labelConverter = new IndexToString()
                .setInputCol("prediction")
                .setOutputCol("realLabel")
                .setLabels(labelIndexer.labels)

            val pipeline = new Pipeline()
                .setStages(Array(assembler, labelIndexer, booster, labelConverter))
            println("Start Trainning ......")
            val model = pipeline.fit(training)
            println("End Trainning ......")

            // Batch prediction
            println("Predicting ...")
            val prediction = model.transform(test)
            prediction.show(5, false)

            // Model evaluation.
            // The evaluator's default metric is F1; request accuracy explicitly so the
            // printed value (and the CV selection criterion below) is really accuracy.
            val evaluator = new MulticlassClassificationEvaluator()
                .setLabelCol("classIndex")
                .setPredictionCol("prediction")
                .setMetricName("accuracy")
            val accuracy = evaluator.evaluate(prediction)
            println("The model accuracy is : " + accuracy)

            // Tune model using cross validation
            val paramGrid = new ParamGridBuilder()
                .addGrid(booster.maxDepth, Array(3, 8))
                .addGrid(booster.eta, Array(0.2, 0.6))
                .build()
            val cv = new CrossValidator()
                .setEstimator(pipeline)
                .setEvaluator(evaluator)
                .setEstimatorParamMaps(paramGrid)
                .setNumFolds(3)

            println("Start CV Trainning ......")
            val cvModel = cv.fit(training)
            println("End CV Trainning ......")

            // Stage index 2 is the booster (see the setStages call above).
            val bestModel = cvModel.bestModel.asInstanceOf[PipelineModel].stages(2)
                .asInstanceOf[XGBoostClassificationModel]
            println("The params of best XGBoostClassification model : " +
                bestModel.extractParamMap())
            println("The training summary of best XGBoostClassificationModel : " +
                bestModel.summary)

            // Export the XGBoostClassificationModel as local XGBoost model,
            // then you can load it back in local Python environment.
            bestModel.nativeBooster.saveModel(nativeModelPath)

            // ML pipeline persistence.
            // NOTE(review): this persists `model` (the pipeline trained before cross
            // validation), not `cvModel.bestModel`, although the message printed
            // below says "CV Best" -- confirm which model is meant to be served.
            model.write.overwrite().save(pipelineModelPath)

            // Load a saved model and serving
            val model2 = PipelineModel.load(pipelineModelPath)
            println("CV Best Predicting ...")
            model2.transform(test).show(false)
        } finally {
            // Always release the session, even when a stage above throws.
            spark.stop()
        }
    }
}

 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值