使用 Scala（Spark ML）构建机器学习模型的主要步骤示例

本文以 Spark ML 为例，演示用 Scala 构建机器学习模型的主要步骤。这里只列出了最基本的环节，可以在此基础上进行扩充。

object mlExample {

  /** End-to-end Spark ML example.
    *
    * Steps: load JSON data, split it into train/test, tune an
    * RFormula + LogisticRegression pipeline with TrainValidationSplit,
    * evaluate the tuned model on the held-out split, inspect the best
    * model's training history, and save/reload the fitted tuning model.
    *
    * @param args unused
    */
  def main(args: Array[String]): Unit = {
    import org.apache.spark.sql.SparkSession
    import org.apache.spark.ml.{Pipeline, PipelineModel}
    import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel}
    import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
    import org.apache.spark.ml.feature.RFormula
    import org.apache.spark.ml.tuning.{ParamGridBuilder, TrainValidationSplit, TrainValidationSplitModel}

    val spark = SparkSession.builder()
      .appName("TobyGao")
      .enableHiveSupport()
      .getOrCreate()

    // Always release the session, even if a step below fails.
    try {
      val modelPath = "/user/Tobygao/model_saved"
      val dataPath = "/user/Tobygao/ml_data"
      val savedModelPath = s"$modelPath/tmp/modelLocation"

      // 1 - load data
      val df = spark.read.json(s"$dataPath/data/simple-ml")

      // 2 - train/test split (70% / 30%)
      val Array(train, test) = df.randomSplit(Array(0.7, 0.3))

      // 3 - feature vector: RFormula (VectorAssembler would also work).
      // The actual formula is supplied by the parameter grid in step 6.
      val rForm = new RFormula()

      // 4 - define the model
      val lr = new LogisticRegression()
        .setLabelCol("label")
        .setFeaturesCol("features")
      println(lr.explainParams())

      // 5 - pipeline: formula -> logistic regression
      val pipeline = new Pipeline().setStages(Array(rForm, lr))

      // 6 - parameter grid (2 formulas x 3 elastic-net mixes x 2 reg params)
      val params = new ParamGridBuilder()
        .addGrid(rForm.formula, Array(
          "lab ~ . + color:value1",
          "lab ~ . + color:value1 + color:value2"))
        .addGrid(lr.elasticNetParam, Array(0.0, 0.5, 1.0))
        .addGrid(lr.regParam, Array(0.1, 2.0))
        .build()

      // 7 - evaluator. AUC must be computed from the continuous scores
      // (rawPrediction), not from the hard 0/1 "prediction" column, which
      // would collapse the ROC curve to a single threshold.
      val evaluator = new BinaryClassificationEvaluator()
        .setMetricName("areaUnderROC")
        .setRawPredictionCol("rawPrediction")
        .setLabelCol("label")

      // 8 - TrainValidationSplit
      val tvs = new TrainValidationSplit()
        .setTrainRatio(0.75) // train / validation split ratio within `train`
        .setEstimatorParamMaps(params) // parameter grid
        .setEstimator(pipeline) // estimator
        .setEvaluator(evaluator) // evaluator

      // 9 - model fit
      val tvsFitted = tvs.fit(train)

      // 10 - model predict
      val tvsPredict = tvsFitted.transform(test)
      tvsPredict.show()

      // 11 - evaluate on the held-out test split
      println(s"AUC ${evaluator.evaluate(tvsPredict)}")

      // 12 - inspect the best model. The estimator is a Pipeline, so the
      // best model is a PipelineModel; locate the LR stage by type instead
      // of relying on its position.
      val trainedPipeline = tvsFitted.bestModel.asInstanceOf[PipelineModel]
      trainedPipeline.stages
        .collectFirst { case m: LogisticRegressionModel => m }
        .filter(_.hasSummary) // `summary` throws when no training summary exists
        .foreach { trainedLR =>
          // objectiveHistory holds the objective value after each training
          // iteration; use it to judge convergence (more iterations, early
          // stopping, or further tuning).
          println(trainedLR.summary.objectiveHistory.mkString(","))
        }

      // 13 - model save
      tvsFitted.write.overwrite().save(savedModelPath)

      // 14 - model load; transform is lazy, so trigger it with an action
      val model = TrainValidationSplitModel.load(savedModelPath)
      model.transform(test).show()
    } finally {
      spark.stop()
    }
  }
}

结果:

1-load data
+-----+----+------+------------------+
|color| lab|value1|            value2|
+-----+----+------+------------------+
|green|good|     1|14.386294994851129|
|green| bad|    16|14.386294994851129|
| blue| bad|     8|14.386294994851129|
| blue| bad|     8|14.386294994851129|
| blue| bad|    12|14.386294994851129|
|green| bad|    16|14.386294994851129|
|green|good|    12|14.386294994851129|
|  red|good|    35|14.386294994851129|
|  red|good|    35|14.386294994851129|
|  red| bad|     2|14.386294994851129|
|  red| bad|    16|14.386294994851129|
|  red| bad|    16|14.386294994851129|
| blue| bad|     8|14.386294994851129|
|green|good|     1|14.386294994851129|
|green|good|    12|14.386294994851129|
| blue| bad|     8|14.386294994851129|
|  red|good|    35|14.386294994851129|
| blue| bad|    12|14.386294994851129|
|  red| bad|    16|14.386294994851129|
|green|good|    12|14.386294994851129|
+-----+----+------+------------------+ 

3-RFormula
+-----+----+------+------------------+--------------------+-----+
|color| lab|value1|            value2|            features|label|
+-----+----+------+------------------+--------------------+-----+
|green|good|     1|14.386294994851129|(10,[1,2,3,5,8],[...|  1.0|
| blue| bad|     8|14.386294994851129|(10,[2,3,6,9],[8....|  0.0|
| blue| bad|    12|14.386294994851129|(10,[2,3,6,9],[12...|  0.0|
|green|good|    15| 38.97187133755819|(10,[1,2,3,5,8],[...|  1.0|
|green|good|    12|14.386294994851129|(10,[1,2,3,5,8],[...|  1.0|
|green| bad|    16|14.386294994851129|(10,[1,2,3,5,8],[...|  0.0|
|  red|good|    35|14.386294994851129|(10,[0,2,3,4,7],[...|  1.0|
|  red| bad|     1| 38.97187133755819|(10,[0,2,3,4,7],[...|  0.0|
|  red| bad|     2|14.386294994851129|(10,[0,2,3,4,7],[...|  0.0|
|  red| bad|    16|14.386294994851129|(10,[0,2,3,4,7],[...|  0.0|
|  red|good|    45| 38.97187133755819|(10,[0,2,3,4,7],[...|  1.0|
|green|good|     1|14.386294994851129|(10,[1,2,3,5,8],[...|  1.0|
| blue| bad|     8|14.386294994851129|(10,[2,3,6,9],[8....|  0.0|
| blue| bad|    12|14.386294994851129|(10,[2,3,6,9],[12...|  0.0|
|green|good|    15| 38.97187133755819|(10,[1,2,3,5,8],[...|  1.0|
|green|good|    12|14.386294994851129|(10,[1,2,3,5,8],[...|  1.0|
|green| bad|    16|14.386294994851129|(10,[1,2,3,5,8],[...|  0.0|
|  red|good|    35|14.386294994851129|(10,[0,2,3,4,7],[...|  1.0|
|  red| bad|     1| 38.97187133755819|(10,[0,2,3,4,7],[...|  0.0|
|  red| bad|     2|14.386294994851129|(10,[0,2,3,4,7],[...|  0.0|
+-----+----+------+------------------+--------------------+-----+ 

10- model prediction
+-----+----+------+------------------+--------------------+-----+--------------------+--------------------+----------+
|color| lab|value1|            value2|            features|label|       rawPrediction|         probability|prediction|
+-----+----+------+------------------+--------------------+-----+--------------------+--------------------+----------+
| blue| bad|     8|14.386294994851129|(7,[2,3,6],[8.0,1...|  0.0|[1.81841935188104...|[0.86037635368405...|       0.0|
| blue| bad|     8|14.386294994851129|(7,[2,3,6],[8.0,1...|  0.0|[1.81841935188104...|[0.86037635368405...|       0.0|
| blue| bad|     8|14.386294994851129|(7,[2,3,6],[8.0,1...|  0.0|[1.81841935188104...|[0.86037635368405...|       0.0|
| blue| bad|     8|14.386294994851129|(7,[2,3,6],[8.0,1...|  0.0|[1.81841935188104...|[0.86037635368405...|       0.0|
| blue| bad|     8|14.386294994851129|(7,[2,3,6],[8.0,1...|  0.0|[1.81841935188104...|[0.86037635368405...|       0.0|
| blue| bad|    12|14.386294994851129|(7,[2,3,6],[12.0,...|  0.0|[2.15923553226233...|[0.89652865416576...|       0.0|
| blue| bad|    12|14.386294994851129|(7,[2,3,6],[12.0,...|  0.0|[2.15923553226233...|[0.89652865416576...|       0.0|
| blue| bad|    12|14.386294994851129|(7,[2,3,6],[12.0,...|  0.0|[2.15923553226233...|[0.89652865416576...|       0.0|
|green| bad|    16|14.386294994851129|[0.0,1.0,16.0,14....|  0.0|[-0.6607070390540...|[0.34058080292169...|       1.0|
|green| bad|    16|14.386294994851129|[0.0,1.0,16.0,14....|  0.0|[-0.6607070390540...|[0.34058080292169...|       1.0|
|green| bad|    16|14.386294994851129|[0.0,1.0,16.0,14....|  0.0|[-0.6607070390540...|[0.34058080292169...|       1.0|
|green|good|     1|14.386294994851129|[0.0,1.0,1.0,14.3...|  1.0|[-0.4860199728364...|[0.38083160751668...|       1.0|
|green|good|     1|14.386294994851129|[0.0,1.0,1.0,14.3...|  1.0|[-0.4860199728364...|[0.38083160751668...|       1.0|
|green|good|    12|14.386294994851129|[0.0,1.0,12.0,14....|  1.0|[-0.6141238213959...|[0.35111907339348...|       1.0|
|green|good|    12|14.386294994851129|[0.0,1.0,12.0,14....|  1.0|[-0.6141238213959...|[0.35111907339348...|       1.0|
|green|good|    15| 38.97187133755819|[0.0,1.0,15.0,38....|  1.0|[-1.1954765118448...|[0.23228089715736...|       1.0|
|green|good|    15| 38.97187133755819|[0.0,1.0,15.0,38....|  1.0|[-1.1954765118448...|[0.23228089715736...|       1.0|
|  red| bad|     1| 38.97187133755819|[1.0,0.0,1.0,38.9...|  0.0|[1.34210087888720...|[0.79283521791024...|       0.0|
|  red| bad|     1| 38.97187133755819|[1.0,0.0,1.0,38.9...|  0.0|[1.34210087888720...|[0.79283521791024...|       0.0|
|  red| bad|     2|14.386294994851129|[1.0,0.0,2.0,14.3...|  0.0|[1.80828707711963...|[0.85915472458301...|       0.0|
+-----+----+------+------------------+--------------------+-----+--------------------+--------------------+----------+ 

11 - evaluator
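AUC 0.9210526315789473
（注：该值是在 setRawPredictionCol("prediction") 的设置下、用硬分类结果 0/1 计算得到的；改用 rawPrediction 列计算 AUC 更为合理，数值也会不同。）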

12 - objectiveHistory
0.6930670630541909,0.5961979573995868,0.5268590504745408,0.47249879942722717,0.4635853436671372,0.4548517016401766,0.4501786908817158,0.44601002558944336,0.4436249409133597,0.4416524078673581,0.4415889464730704,0.44157753890376344,0.44157750876351043,0.441577487735776,0.44157748049670753,0.4415774778777291,0.44157747771911865,0.44157747771846245

13 - model save
12.0 K  /user/Tobygao/model_saved/tmp/modelLocation/bestModel
1.1 K   /user/Tobygao/model_saved/tmp/modelLocation/estimator
376     /user/Tobygao/model_saved/tmp/modelLocation/evaluator
3.7 K   /user/Tobygao/model_saved/tmp/modelLocation/metadata
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值