【Spark MLlib】(三)Spark MLlib 数据基础_spark mlib中那个函数用于计算两个向量的点积

//正确率
val evaluator1 = new MulticlassClassificationEvaluator()
  .setLabelCol("indexedLabel")
  .setPredictionCol("prediction")
  .setMetricName("accuracy")
val accuracy = evaluator1.evaluate(predictions)
println(accuracy)
 
//f1
val evaluator2 = new MulticlassClassificationEvaluator()
  .setLabelCol("indexedLabel")
  .setPredictionCol("prediction")
  .setMetricName("f1")
val f1 = evaluator2.evaluate(predictions)
println(f1)
 
//Precision
val evaluator3 = new MulticlassClassificationEvaluator()
  .setLabelCol("indexedLabel")
  .setPredictionCol("prediction")
  .setMetricName("weightedPrecision")
val Precision = evaluator3.evaluate(predictions)
println(Precision)
 
//Recall
val evaluator4 = new MulticlassClassificationEvaluator()
  .setLabelCol("indexedLabel")
  .setPredictionCol("prediction")
  .setMetricName("weightedRecall")
val Recall = evaluator4.evaluate(predictions)
println(Recall)
 
//AUC
val evaluator5 = new BinaryClassificationEvaluator()
  .setLabelCol("indexedLabel")
  .setRawPredictionCol("prediction")
  .setMetricName("areaUnderROC")
val auc = evaluator5.evaluate(predictions)
println(auc)
 
//aupr
val evaluator6 = new BinaryClassificationEvaluator()
  .setLabelCol("indexedLabel")
  .setRawPredictionCol("prediction")
  .setMetricName("areaUnderPR")
val aupr = evaluator6.evaluate(predictions)
println(aupr)

三、交叉-验证方法

交叉验证法先将数据集D划分为k个大小相似的互斥子集,即D=D1并D2并…并Dk,每个子集之间没有交集。然后每次用k-1个子集的并集作为训练集,余下的那个作为测试集,这样得到k组训练/测试集。可以进行k次训练和测试,最终返回的是这个k个结果的均值。可以随机使用不同的划分多次,例如:10次10折交叉验证通常把交叉验证法称为“k折交叉验证”(k-fold cross validation),k最常用的取值时10,为10折交叉验证。

示例:交叉验证

package sparkml
 
import org.apache.log4j.{Level, Logger}
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
import org.apache.spark.ml.feature.{HashingTF, Tokenizer}
import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder}
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.Row
import org.apache.spark.ml.linalg.Vector
 
object JiaoChaYanZheng {
  def main(args: Array[String]): Unit = {
    //设置日志输出级别
    Logger.getLogger("org").setLevel(Level.WARN)
 
    //定义SparkSession
    val spark = SparkSession.builder()
      .appName("jcyz")
      .master("local[\*]")
      .getOrCreate()
 
    import spark.implicits._
 
    //样本数据,格式为(id, text, label)
    val training = spark.createDataFrame(Seq(
      (0L, "a b c d e spark", 1.0),
      (1L, "b d", 0.0),
      (2L, "spark f g h", 0.0),
      (3L, "hadoop mapreduce", 0.0),
      (4L, "b spark who", 1.0),
      (5L, "g d a y", 0.0),
      (6L, "spark fly", 1.0),
      (7L, "was mapreduce", 0.0),
      (8L, "e spark program", 1.0),
      (9L, "a e c l", 0.0),
      (10L, "spark compile", 1.0),
      (11L, "hadoop software", 0.0)
    )).toDF("id", "text", "label")
 
    //建立ML管道,包括:tokenizer,hashingTF,lr
    val tokenizer = new Tokenizer()
      .setInputCol("text")
      .setOutputCol("words")
    val hashingTF = new HashingTF()
      .setInputCol(tokenizer.getOutputCol)
      .setOutputCol("features")
    val lr = new LogisticRegression()
      .setMaxIter(10)
    val pipeline = new Pipeline()
      .setStages(Array(tokenizer, hashingTF, lr))
 
    //采用ParamGridBuilde方法来建立网格搜索
    //网格的参数包括:hashingTF.numFeatures 3个参数,lr.regParam 2个参数
    //网格总共大小为:3 \* 2 = 6,采用交叉验证来选择最优参数
    val paramGrid = new ParamGridBuilder()
      .addGrid(hashingTF.numFeatures, Array(10, 100, 1000))
      .addGrid(lr.regParam, Array(0.1, 0.01))
      .build()
 
    //建立一个交叉验证的评估器,设置评估的参数
    val cv = new CrossValidator()
      .setEstimator(pipeline)
      .setEvaluator(new BinaryClassificationEvaluator())
      .setEstimatorParamMaps(paramGrid)
      .setNumFolds(2)
 
    //运行交叉验证评估器,得到最佳参数集的模型
    val cvModel = cv.fit(training)
 
    //测试数据
    val test = spark.createDataFrame(Seq(
      (4L, "spark i j k"),
      (5L, "l m n"),
      (6L, "mapreduce spark"),
      (7L, "apache hadoop")
    )).toDF("id", "text")
 
    //测试,cvModel会选择最佳的lrModel进行预测
    val result = cvModel.transform(test)
    result.select("id", "text", "probability", "prediction")
      .collect()
      .foreach{
        case Row(id: Long, text: String, prob: Vector, prediction: Double) =>
          println(s"($id, $text) --> prob = $prob, prediction = $prediction")
      }
  }
 
}

img
img

网上学习资料一大堆,但如果学到的知识不成体系,遇到问题时只是浅尝辄止,不再深入研究,那么很难做到真正的技术提升。

需要这份系统化资料的朋友,可以戳这里获取

一个人可以走的很快,但一群人才能走的更远!不论你是正从事IT行业的老鸟或是对IT行业感兴趣的新人,都欢迎加入我们的的圈子(技术交流、学习资源、职场吐槽、大厂内推、面试辅导),让我们一起学习成长!

要这份系统化资料的朋友,可以戳这里获取](https://bbs.csdn.net/topics/618545628)**

一个人可以走的很快,但一群人才能走的更远!不论你是正从事IT行业的老鸟或是对IT行业感兴趣的新人,都欢迎加入我们的的圈子(技术交流、学习资源、职场吐槽、大厂内推、面试辅导),让我们一起学习成长!

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值