Online Prediction for Machine Learning Models: Workflow and Solution

Once a model has been trained offline, how do we run predictions with it? In particular, how do we score requests online in real time?

We can convert the trained model into a file in the PMML format and hand it to the Openscoring web service component for online scoring. Openscoring is a complete solution for this (official site: openscoring.io). On GitHub, the developer vruusmann has implemented a REST web service on top of the Openscoring components for real-time scoring of R, Scikit-Learn and Apache Spark models; it is likewise named openscoring, and scoring latency is reportedly under 1 ms.

Openscoring REST web service code

Fork path: https://github.com/dearbaba/openscoring

Original path: https://github.com/openscoring/openscoring

1. Deploying the Openscoring Service

 
java -Dconfig.file=application.conf -Djava.util.logging.config.file=logging.properties -jar openscoring-server-executable-${version}.jar

Pushing a model:

 
# lrmodel is the model ID
curl -X PUT --data-binary @lrmodel.pmml -H "Content-type: text/xml" http://localhost:8080/openscoring/model/lrmodel

Scoring a prediction:

 
curl -X POST --data-binary @lrmodel.json -H "Content-type: application/json" http://localhost:8080/openscoring/model/lrmodel
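The lrmodel.json body follows Openscoring's evaluation-request format: a record id plus an "arguments" map keyed by the model's input fields. As a minimal sketch, the same call can also be made programmatically from Scala; the field names x1 and x2 below are placeholders and must be replaced with the input fields actually declared in your PMML file:

import java.net.{HttpURLConnection, URL}
import java.nio.charset.StandardCharsets
import scala.io.Source

object OpenscoringClient {
  def main(args: Array[String]): Unit = {
    // Evaluation request: a record id plus the input field values.
    // "x1"/"x2" are placeholder field names, not part of the original post.
    val body = """{"id": "record-001", "arguments": {"x1": 1.0, "x2": 0.0}}"""

    val conn = new URL("http://localhost:8080/openscoring/model/lrmodel")
      .openConnection().asInstanceOf[HttpURLConnection]
    conn.setRequestMethod("POST")
    conn.setRequestProperty("Content-Type", "application/json")
    conn.setDoOutput(true)
    val out = conn.getOutputStream
    out.write(body.getBytes(StandardCharsets.UTF_8))
    out.close()

    // A successful reply echoes the id and carries a "results" object,
    // e.g. {"id": "record-001", "results": {...}}
    println(Source.fromInputStream(conn.getInputStream, "UTF-8").mkString)
  }
}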

For the other endpoints, see the documentation in the Git repository.

2. Exporting the Model Result

For example, we train our model with Spark ML and save the result as a PMML file; the code is roughly as follows:

 
import java.io.File
import java.text.SimpleDateFormat
import java.util.Date

import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer

import org.apache.hadoop.fs.Path
import org.apache.spark.ml.{Pipeline, PipelineModel}
import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel}
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
import org.apache.spark.ml.feature.{CountVectorizer, CountVectorizerModel}
import org.apache.spark.sql.{Dataset, Row, SparkSession}
import org.jpmml.sparkml.PMMLBuilder

def lr_training(hiveContext: SparkSession, appid: String): Unit = {
  import hiveContext.implicits._

  // Load the training samples: userid, label and a map of tags per user
  val data = hiveContext.sql(s"""
    select userid, label, b.tagmap from table_name
    """.stripMargin).rdd.map(row => {
    val features = ArrayBuffer[String]()
    val tdid = row.getAs[String]("userid")
    val label = row.getAs[Int]("label") // positive/negative sample flag
    val tagmap = row.getAs[scala.collection.immutable.Map[String, String]]("tagmap")
    if (tagmap.size > 0) {
      // use the tag keys as bag-of-words style features
      for (key <- tagmap.keys) {
        features += key
      }
    }
    (tdid, label.toDouble, features.toArray)
  })

  // The original listing is garbled from here to the weight map below; this
  // reconstruction assumes a CountVectorizer + LogisticRegression pipeline,
  // consistent with the imports and the variables used afterwards.
  val trainingDF = data.toDF("userid", "label", "tags")
  val countVectorizer = new CountVectorizer().setInputCol("tags").setOutputCol("features")
  val lr = new LogisticRegression().setLabelCol("label").setFeaturesCol("features")
  val pipelineModel = new Pipeline().setStages(Array(countVectorizer, lr)).fit(trainingDF)
  val cvModel = pipelineModel.stages(0).asInstanceOf[CountVectorizerModel]
  val lrModel = pipelineModel.stages(1).asInstanceOf[LogisticRegressionModel]

  // AUC of the fitted model on the training set
  val auc = new BinaryClassificationEvaluator().setLabelCol("label")
    .evaluate(pipelineModel.transform(trainingDF))

  // map every vocabulary term to its learned weight
  val kvWeightData = mutable.Map[String, Double]()
  for (i <- 0 until cvModel.vocabulary.length) {
    kvWeightData += (cvModel.vocabulary(i) -> lrModel.coefficients(i))
  }
  kvWeightData += ("intercept_b" -> lrModel.intercept)
  kvWeightData += ("auc" -> auc)

  // save the feature weights to an HDFS text file
  val dateFormat = new SimpleDateFormat("yyyyMMdd")
  val todayDate = dateFormat.format(new Date())
  // hdfs_uri is assumed to be defined elsewhere in the enclosing object
  val feature_importance_save_path = hdfs_uri + s"/model/result/${todayDate}/"
  hiveContext.sparkContext.parallelize(kvWeightData.toArray[(String, Double)])
    .map(row => row._1 + "\t" + row._2)
    .repartition(1)
    .saveAsTextFile(feature_importance_save_path)

  // PMML export: target directory and file name on HDFS
  val hdfs_path = hdfs_uri + s"/data/spark/rym/models/"
  val hdfs_file_name = "lrmodel.pmml"
  save_to_PMML(trainingDF, pipelineModel, hdfs_path, hdfs_file_name)
}

/**
 * Save the fitted pipeline to a PMML file and upload it to HDFS.
 */
def save_to_PMML(trainingDF: Dataset[Row], pipelineModel: PipelineModel, hdfs_path: String, hdfs_file_name: String): Unit = {
  println("start to save model to pmml file ... ...")
  // build the PMML document from the training schema and the fitted
  // pipeline, writing it to a local file first
  val pmmlBuilder = new PMMLBuilder(trainingDF.schema, pipelineModel)
  pmmlBuilder.buildFile(new File(hdfs_file_name))

  // then copy the local file into HDFS, creating the target folder if needed
  // (hdfs_uri is assumed to be defined elsewhere in the enclosing object)
  val hdfs = org.apache.hadoop.fs.FileSystem.get(new java.net.URI(hdfs_uri), new org.apache.hadoop.conf.Configuration())
  val dst_path = new Path(hdfs_path)
  if (!hdfs.exists(dst_path)) {
    hdfs.mkdirs(dst_path)
  }
  hdfs.copyFromLocalFile(new Path(hdfs_file_name), dst_path)
}
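Before pushing lrmodel.pmml to Openscoring, it can be sanity-checked locally with JPMML-Evaluator, a sibling project of the jpmml-sparkml library used above. Below is a minimal sketch, assuming JPMML-Evaluator 1.6+ on the classpath and the file in the working directory; it verifies the model and prints the field names that the "arguments" map of an Openscoring request must supply:

import java.io.File
import scala.collection.JavaConverters._
import org.jpmml.evaluator.LoadingModelEvaluatorBuilder

object PmmlSanityCheck {
  def main(args: Array[String]): Unit = {
    val evaluator = new LoadingModelEvaluatorBuilder()
      .load(new File("lrmodel.pmml"))
      .build()

    // Replays the verification records embedded in the PMML, if any
    evaluator.verify()

    // Input fields: the keys the Openscoring "arguments" map must supply
    evaluator.getInputFields.asScala.foreach(f => println(f.getName))
    // Target fields: what comes back in the "results" object
    evaluator.getTargetFields.asScala.foreach(f => println(f.getName))
  }
}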