spark ML 机器学习包的使用

  val spark = SparkSession.builder().config(new SparkConf().setMaster("local[*]")).getOrCreate()
    val training = spark.createDataFrame(Seq(
      (0L, "a b c d e spark", 1.0),
      (1L, "b d", 0.0),
      (2L, "spark f g h", 1.0),
      (3L, "hadoop mapreduce", 0.0),
      (4L, "b spark who", 1.0),
      (5L, "g d a y", 0.0),
      (6L, "spark fly", 1.0),
      (7L, "was mapreduce", 0.0),
      (8L, "e spark program", 1.0),
      (9L, "a e c l", 0.0),
      (10L, "spark compile", 1.0),
      (11L, "hadoop software", 0.0)
    )).toDF("id", "text", "label")

    // Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr.
    val tokenizer = new Tokenizer()
      .setInputCol("text")
      .setOutputCol("words")
    val hashingTF = new HashingTF()
      .setInputCol(tokenizer.getOutputCol)
      .setOutputCol("features")
    val lr = new LogisticRegression()
      .setMaxIter(10)


    // val lrStartTime = new Date().getTime

    val pipeline = new Pipeline()
      .setStages(Array(tokenizer, hashingTF, lr))

    // 我们使用ParamGridBuilder来构建要搜索的参数网格。
    // 使用hashingTF.numFeatures的3个值和lr.regParam的2个值, regParam正则化参数
    // 此网格将有3 x 2 = 6个参数设置供CrossValidator选择。

    //    交叉验证参数设定和模型
    val paramGrid = new ParamGridBuilder()
      .addGrid(hashingTF.numFeatures, Array(10, 100, 1000))
      .addGrid(lr.regParam, Array(0.1, 0.01))
      .addGrid(lr.elasticNetParam,Array(0.1,0.0))
      .build()

    // 模型选择与调参的三个基本组件分别是 Estimator、ParamGrid 和 Evaluator,
    // 其中Estimator包括算法或者Pipeline;
    // ParamGrid即ParamMap集合,提供参数搜索空间;
    // Evaluator 即评价指标。


    // 我们现在将Pipeline视为Estimator,将其包装在CrossValidator实例中。
    // 这将允许我们共同选择所有Pipeline阶段的参数。
    // CrossValidator需要Estimator,一组Estimator ParamMaps和一个Evaluator。
    // 请注意,此处的求值程序是BinaryClassificationEvaluator,其默认度量
    // 是areaUnderROC。
    // 交叉验证
    // BinaryClassificationEvaluator 二值数据的评估
    // RegressionEvaluator   回归
    // MulticlassClassificationEvaluator 多分类
    val cv = new CrossValidator()
      .setEstimator(pipeline)  //要优化的pipeline
      .setEvaluator(new BinaryClassificationEvaluator)  // 评价指标
      .setEstimatorParamMaps(paramGrid)   // 参数搜索
      .setNumFolds(2)  //   使用几折交叉验证

    //运行交叉验证,并选择最佳参数集。
    val cvModel = cv.fit(training)
    val bestLrModel = cvModel.bestModel.asInstanceOf[PipelineModel]
    val bestHash = bestLrModel.stages(1).asInstanceOf[HashingTF]
    val bestHashFeature = bestHash.getNumFeatures
    val bestLr = bestLrModel.stages(2).asInstanceOf[LogisticRegressionModel]
    val blr = bestLr.getRegParam
    val ble = bestLr.getElasticNetParam
    println(s"HashTF最优参数:\ngetNumFeatures= $bestHashFeature \n逻辑回归模型最优参数:\nregParam = $blr,elasticNetParam = $ble")

    // Prepare test documents, which are unlabeled (id, text) tuples.
    val test = spark.createDataFrame(Seq(
      (4L, "spark i j k","1.0"),
      (5L, "l m n","0.0"),
      (6L, "mapreduce spark","1.0"),
      (7L, "apache hadoop","0.0")
    )).toDF("id", "text","label")

    cvModel.transform(test)
      .select("id", "text", "probability", "prediction")
      .collect()
      .foreach { case Row(id: Long, text: String, prob: Vector, prediction: Double) =>
        println(s"($id, $text) --> prob=$prob, prediction=$prediction")
      }

输出
{
logreg_ced5bcd864cb-elasticNetParam: 0.1,
hashingTF_d06823caef37-numFeatures: 100,
logreg_ced5bcd864cb-regParam: 0.1
}
是由配置日志文件
log4j.logger.org.apache.spark.ml.tuning.TrainValidationSplit=INFO
log4j.logger.org.apache.spark.ml.tuning.CrossValidator=INFO
所生成。

HashTF最优参数:
getNumFeatures= 100
逻辑回归模型最优参数:
regParam = 0.1,elasticNetParam = 0.1

(4, spark i j k) --> prob=[0.18790027301642065,0.8120997269835794], prediction=1.0
(5, l m n) --> prob=[0.8895957990095681,0.11040420099043186], prediction=0.0
(6, mapreduce spark) --> prob=[0.33625307291668444,0.6637469270833155], prediction=1.0
(7, apache hadoop) --> prob=[0.706788417403727,0.293211582596273], prediction=0.0

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值