Learning Spark: Deploying a Spark Pipeline Model with MLeap

1. Required dependencies

        <dependency>
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-mllib_2.11</artifactId>
            <version>2.2.0</version>
        </dependency>
        <dependency>
            <groupId>ml.combust.mleap</groupId>
            <artifactId>mleap-spark_2.11</artifactId>
            <version>0.11.0</version>
        </dependency>
        <dependency>
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-sql_2.11</artifactId>
            <version>2.2.0</version>
        </dependency>
        <dependency>
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-hive_2.11</artifactId>
            <version>2.2.0</version>
        </dependency>
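
If you build with sbt instead of Maven, the equivalent dependencies would look like this (a sketch assuming the same Scala 2.11 / Spark 2.2.0 / MLeap 0.11.0 versions as above):

libraryDependencies ++= Seq(
  "org.apache.spark" %% "spark-mllib" % "2.2.0",
  "org.apache.spark" %% "spark-sql"   % "2.2.0",
  "org.apache.spark" %% "spark-hive"  % "2.2.0",
  "ml.combust.mleap" %% "mleap-spark" % "0.11.0"
)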

2. Code

2.1 Data preprocessing, model training, and saving

import ml.combust.bundle.BundleFile
import ml.combust.mleap.spark.SparkSupport._
import org.apache.spark.ml.bundle.SparkBundleContext
import org.apache.spark.ml.classification.{GBTClassifier, LogisticRegression}
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
import org.apache.spark.ml.feature.{OneHotEncoder, StringIndexer, VectorAssembler}
import org.apache.spark.sql._
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.ml.{Pipeline, PipelineStage}
import resource.managed

object TrainModelLeap {
  def main(args: Array[String]): Unit = {
    val sparkConf = new SparkConf().setMaster("local[2]").setAppName("zhunshibao_test")
    val sc = new SparkContext(sparkConf)
    val spark = SparkSession.builder()
      .config(sc.getConf)
      .config("hive.metastore.uris", "thrift://10.202.77.200:9083")
      .enableHiveSupport()
      .getOrCreate()

    // Project-specific SQL that selects the training data.
    val data = spark.sql(constantConfig.testTrainDataSetSql).na.drop

    val Array(trainSet, testSet) = data.randomSplit(Array(0.8, 0.2), 2L)
    trainSet.show(5)

    var dataProcessList = List[PipelineStage]()

    /** StringIndexer: encode the categorical string columns as numeric indices **/
    val stringColumns = Array("new_cargo", "oprid")
    var stringColumnsInc = List[String]()

    for (field <- stringColumns) {
      val indexer = new StringIndexer().setInputCol(field).setOutputCol(field + "Inc").setHandleInvalid("skip")
      dataProcessList = dataProcessList :+ indexer
      stringColumnsInc = stringColumnsInc :+ (field + "Inc")
    }

    /** Assemble the numeric feature columns into a single vector.
        Note that the indexed string columns are produced by the pipeline
        but are not added to the feature vector here. **/
    val assembler = new VectorAssembler().setInputCols(Array("top1", "top2")).setOutputCol("features")
    dataProcessList = dataProcessList :+ assembler

    /** Models **/
    val lr = new LogisticRegression()
      .setMaxIter(10)
      .setRegParam(0.01)
      .setFeaturesCol("features")
      .setLabelCol("output1")
      .setPredictionCol("predict")

    val gbt = new GBTClassifier()
      .setLabelCol("output1")
      .setFeaturesCol("features")
      .setPredictionCol("gbt_prediction")
      .setProbabilityCol("gbt_prediction_prob")
      .setRawPredictionCol("gbt_prediction_raw")
      .setMaxBins(80)
      .setMaxIter(50)

    // Append both classifiers after the preprocessing stages.
    dataProcessList = dataProcessList :+ lr :+ gbt
    val pipeline = new Pipeline().setStages(dataProcessList.toArray)
    val model = pipeline.fit(trainSet)
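    // Optional: evaluate the fitted pipeline on the held-out split with the
    // BinaryClassificationEvaluator imported above. A sketch, not part of the
    // original flow; it assumes the label column "output1" and the raw
    // prediction column configured on the GBT stage.
    val evaluator = new BinaryClassificationEvaluator()
      .setLabelCol("output1")
      .setRawPredictionCol("gbt_prediction_raw")
    println(s"test AUC = ${evaluator.evaluate(model.transform(testSet))}")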

    val ds = model.transform(trainSet)
    // MLeap infers the bundle's input/output schema from a transformed dataset.
    implicit val context = SparkBundleContext().withDataset(ds)

    /** Save the fitted pipeline as an MLeap bundle **/
    for (bf <- managed(BundleFile("jar:file:/tmp/lc.model.zip"))) {
      model.writeBundle.save(bf)(context).get
    }
  }
}
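
To score from Spark again later, rather than through the MLeap runtime shown in 2.2, the saved bundle can be loaded back as a Spark Transformer. A minimal sketch, using the same imports as the training code and assuming the bundle path above:

val sparkModel = (for (bf <- managed(BundleFile("jar:file:/tmp/lc.model.zip"))) yield {
  bf.loadSparkBundle().get  // deserialize the bundle back into a Spark pipeline model
}).opt.get.root

sparkModel.transform(testSet).show(5)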

2.2 Prediction (single-record and batch)

Pay attention to the format of the data passed into the model's transform: the MLeap schema must declare every input column the pipeline expects (here top1, top2, new_cargo, and oprid) with matching types, and each Row must list its values in the same order as the schema fields.

package com.sf


import ml.combust.bundle.BundleFile
import ml.combust.mleap.core.types.{ScalarType, StructField, StructType}
import ml.combust.mleap.runtime.MleapSupport._
import ml.combust.mleap.runtime.frame.{DefaultLeapFrame, Row}
import ml.combust.mleap.spark.SparkLeapFrame
import ml.combust.mleap.spark.SparkSupport._
import org.apache.spark.sql.SparkSession
import org.apache.spark.{SparkConf, SparkContext}
import resource._


object PredictLeap {
  /** Batch scoring: load the MLeap bundle and transform a SparkLeapFrame. **/
  def getModelProbBatch(frame: SparkLeapFrame): SparkLeapFrame = {
    val bundle = (for (bundleFile <- managed(BundleFile("jar:file:/tmp/lc.model.zip"))) yield {
      bundleFile.loadMleapBundle().get
    }).opt.get

    bundle.root.transform(frame).get
  }

  /** Single-record scoring: build a DefaultLeapFrame by hand and transform it. **/
  def getModelProbOne(data: Seq[Row]): DefaultLeapFrame = {
    val bundle = (for (bundleFile <- managed(BundleFile("jar:file:/tmp/lc.model.zip"))) yield {
      bundleFile.loadMleapBundle().get
    }).opt.get

    // The schema must match the columns and types the pipeline expects.
    val schema = StructType(StructField("top1", ScalarType.Double),
      StructField("top2", ScalarType.Double),
      StructField("new_cargo", ScalarType.String),
      StructField("oprid", ScalarType.String)).get

    val frame = DefaultLeapFrame(schema, data)
    bundle.root.transform(frame).get
  }

  def main(args: Array[String]): Unit = {
    /** Single-record prediction **/
    val dataOne: Seq[Row] = Seq(Row(1.0, 1.0, "other", "1247226"))

    val resOne = getModelProbOne(dataOne)
    resOne.show(5)

    /** Batch prediction on a DataFrame **/
    val sparkConf = new SparkConf().setMaster("local[2]").setAppName("zhunshibao_test")
    val sc = new SparkContext(sparkConf)
    val spark = SparkSession.builder()
      .config(sc.getConf)
      .config("hive.metastore.uris", "thrift://10.202.77.200:9083")
      .enableHiveSupport()
      .getOrCreate()

    // Same project-specific SQL as in training; score the held-out split.
    val dataBatch = spark.sql(constantConfig.testTrainDataSetSql).na.drop
    val Array(_, testSet) = dataBatch.randomSplit(Array(0.8, 0.2), 2L)

    val resBatch = getModelProbBatch(testSet.toSparkLeapFrame)
    resBatch.toSpark.toDF().show(5)
  }

}
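
The returned LeapFrame carries the pipeline's output columns (predict, gbt_prediction, gbt_prediction_prob, gbt_prediction_raw) alongside the inputs. A minimal sketch of reading the GBT prediction from the single-record result resOne above; select, like transform, returns a Try:

val prediction = resOne.select("gbt_prediction").get  // keep just the one output column
  .dataset.head.getDouble(0)                          // first row, first field
println(s"GBT prediction: $prediction")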

Reference:
https://www.bookstack.cn/read/mleap-zh/mleap-runtime-create-leap-frame.md
