【Spark】Spark训练Lr模型,并保存为Pmml

scala版本spark构建的Lr模型:

一、问题背景

  需要构建一个Lr模型来进行物品的Ctr预测。

二、解决方案

  由于我们训练的数据量较多,所以首先考虑采用spark来构建模型并测试训练,这样的效率较高。
  *模型接口详情可以参考spark的scala的API文档:https://spark.apache.org/docs/latest/api/scala/org/apache/spark/index.html,整体代码如下:

import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel}
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
import org.apache.spark.ml.feature.{StringIndexer, StringIndexerModel, VectorAssembler}
import org.apache.spark.sql.SparkSession

import org.jpmml.model.JAXBUtil
import org.jpmml.sparkml.PMMLBuilder
import javax.xml.transform.stream.StreamResult

object CargoClinchLR {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().enableHiveSupport().getOrCreate()
	//原始数据字段
    val path = "hdfs://xxxxxxxx"
    //离散特征字段
	val str_col = Array( "start_city_id", "end_city_id", "start_prov_id", "end_prov_id",
      ……)

    val idx_col = for (c <- str_col) yield s"${c}_idx"
    //连续特征字段
    val num_col = Array("weight", "capacity", "distance")
    val data = spark.read.parquet(path).na.fill("unknown").na.replace(str_col, Map("" -> "unknown"))
    //划分训练集和测试集
    val Array(train, test) = data.randomSplit(Array(0.8, 0.2))
	//这里的特征是进行索引编码的,并没进行onehot操作(常规的lr是进行onehot操作)
    val str_idxers = for (c <- str_col)
      yield new StringIndexer().setInputCol(c).setOutputCol(s"${c}_idx").setHandleInvalid("skip").setStringOrderType("frequencyAsc")
	//离散特征向量与连续特征向量的拼接
    val assember = new VectorAssembler().setInputCols(idx_col ++ num_col).setOutputCol("fea")
    //spark中逻辑回归的模型
    val lr = new LogisticRegression().setFeaturesCol("fea").setLabelCol("label")
    val pip = new Pipeline().setStages(str_idxers ++ Array(assember, lr)).fit(train)
	// pip.write.overwrite().save("XXX").  //保存pip模型
    // val sameModel = PipelineModel.load("/home/a1022856/CargoClinch/spark-logistic-regression-model")                 // 模型导入 

	//模型实例化,获取模型参数
    val lr_model = pip.stages.last.asInstanceOf[LogisticRegressionModel]
    //获取模型的权重
    println(lr_model.coefficients)
	
	//相关二分类相应指标统计值
    val summary = lr_model.binarySummary
    val precision = summary.weightedPrecision
    val recall = summary.weightedRecall
    val accuracy = summary.accurac
    val auc = summary.areaUnderROC
    //打印auc参数
    println(s"train_acc =${auc}")
    
    //二分类的模型评估器,以auc作为评估指标
    val eval = new BinaryClassificationEvaluator().setLabelCol("label")
      .setMetricName("areaUnderROC")
    eval.evaluate(pip.transform(test))
    
	val auc = eval.evaluate(pip.transform(test))
    println(s"eval_acc =${auc}")

    // 保存pmml文件,需要下载对应的pmml转换的包
    val input_col = str_col ++ num_col
    val pmml = new PMMLBuilder(data.schema, pip).build()
	JAXBUtil.marshalPMML(pmml, new StreamResult("model"))
  • 1
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

郝同学

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值