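1. Create the CrossValidator
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder}
val crossval = new CrossValidator()
  .setEstimator(pipeline)
  .setEvaluator(new BinaryClassificationEvaluator)
setEstimator takes the pipeline to be tuned.
setEvaluator takes the evaluator that scores each candidate model. BinaryClassificationEvaluator uses area under the ROC curve by default.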
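2. The pipeline was built earlier, so its construction is omitted here. The hashingTF and lr stages referenced below come from that pipeline.
3. The part worth learning here is that the ML API provides a helper for searching for the best parameters: ParamGridBuilder.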
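val paramGrid = new ParamGridBuilder()
  .addGrid(hashingTF.numFeatures, Array(10, 100, 1000))
  .addGrid(lr.regParam, Array(0.1, 0.01))
  .build()
Each addGrid call registers one parameter to tune together with its candidate values. build() produces every combination, so the grid above contains 3 x 2 = 6 parameter settings.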
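To make the "every combination" idea concrete, here is a minimal sketch in plain Scala, with no Spark dependency. It is not Spark's implementation: ParamGridSketch, ParamSetting and buildGrid are hypothetical names introduced only for illustration, and parameters are modelled as simple name/value pairs.

```scala
// Simplified stand-in for ParamGridBuilder. Illustration only, not Spark code.
object ParamGridSketch {
  // One candidate setting: parameter name -> value (hypothetical representation).
  type ParamSetting = Map[String, Double]

  // Cartesian product of every parameter's candidate values.
  def buildGrid(candidates: Seq[(String, Seq[Double])]): Seq[ParamSetting] =
    candidates.foldLeft(Seq(Map.empty[String, Double])) {
      case (partial, (name, values)) =>
        // Extend every partial setting with each value of the next parameter.
        for (setting <- partial; value <- values) yield setting + (name -> value)
    }

  def main(args: Array[String]): Unit = {
    val grid = buildGrid(Seq(
      "numFeatures" -> Seq(10.0, 100.0, 1000.0),
      "regParam"    -> Seq(0.1, 0.01)
    ))
    grid.foreach(println)
    println(s"${grid.size} combinations") // prints "6 combinations"
  }
}
```

Each extra addGrid multiplies the number of settings, so the grid grows quickly as parameters are added.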
Next, attach the grid to crossval and set the number of folds:
crossval.setEstimatorParamMaps(paramGrid)
crossval.setNumFolds(2) // Use 3+ in practice
// Run cross-validation, and choose the best set of parameters.
val cvModel = crossval.fit(training.toDF())
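What fit does can be pictured as a loop: for every parameter setting, train and score on each fold, average the scores, and keep the setting with the best average. The sketch below shows that idea in plain Scala. It is a simplified illustration under stated assumptions, not Spark's code: CrossValidationSketch, kFold and selectBest are hypothetical names, the folds are assigned by index modulo k rather than by random sampling, and the score function is a toy stand-in for "train on the training fold, evaluate on the validation fold".

```scala
// Conceptual sketch of grid search with k-fold cross-validation. Not Spark code.
object CrossValidationSketch {
  // Split row indices 0 until n into k folds.
  // Returns one (trainIndices, validationIndices) pair per fold.
  def kFold(n: Int, k: Int): Seq[(Seq[Int], Seq[Int])] =
    (0 until k).map { fold =>
      val (validation, train) = (0 until n).partition(_ % k == fold)
      (train, validation)
    }

  // Score every candidate on every fold, average, and return the best
  // candidate with its average score (higher is treated as better).
  def selectBest[P](candidates: Seq[P], n: Int, k: Int)(
      score: (P, Seq[Int], Seq[Int]) => Double): (P, Double) = {
    val folds = kFold(n, k)
    candidates
      .map { p =>
        val avg = folds.map { case (train, validation) => score(p, train, validation) }.sum / k
        (p, avg)
      }
      .maxBy(_._2)
  }

  def main(args: Array[String]): Unit = {
    // Toy score: pretends regParam = 0.01 is ideal. Purely hypothetical.
    val (best, avg) = selectBest(Seq(0.1, 0.01), n = 10, k = 2) {
      (regParam, _, _) => 1.0 - math.abs(regParam - 0.01)
    }
    println(s"best regParam=$best, average score=$avg")
  }
}
```

With the 6 settings and 2 folds used in this note, the loop trains 6 x 2 = 12 models. CrossValidator then refits the pipeline once more on the full training data with the winning setting, and that refit is what cvModel uses for predictions.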
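4. Validate on the test data
// Needs org.apache.spark.sql.Row. Vector is org.apache.spark.ml.linalg.Vector
// on Spark 2.x and later (org.apache.spark.mllib.linalg.Vector on 1.x).
cvModel.transform(test.toDF())
  .select("id", "text", "probability", "prediction")
  .collect()
  .foreach { case Row(id: Long, text: String, prob: Vector, prediction: Double) =>
    println(s"($id, $text) --> prob=$prob, prediction=$prediction")
  }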