spark PIPELINE 的交叉验证

 

1.创建CrossValidator

 

val crossval = new CrossValidator()
  .setEstimator(pipeline)
  .setEvaluator(new BinaryClassificationEvaluator)

 

第一个是pipeline程序

第二个是模型的求值函数

 

2.pipeline以前构造过,这里省过

 

3.这里值得学习的是ML api提供了一个帮助我们寻找最佳参数的api

val paramGrid = new ParamGridBuilder()
  .addGrid(hashingTF.numFeatures, Array(10, 100, 1000))
  .addGrid(lr.regParam, Array(0.1, 0.01))
  .build()

我们通过addGrid添加我们需要寻找的最佳参数

接着往crossval添加

crossval.setEstimatorParamMaps(paramGrid)
crossval.setNumFolds(2) // Use 3+ in practice

 

// Run cross-validation, and choose the best set of parameters.
val cvModel = crossval.fit(training.toDF())

 

4.验证

cvModel.transform(test.toDF())
  .select("id", "text", "probability", "prediction")
  .collect()
  .foreach { case Row(id: Long, text: String, prob: Vector, prediction: Double) =>
  println(s"($id, $text) --> prob=$prob, prediction=$prediction")
}

 

 

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值