object GamingTrain {

  /** One training row: a numeric class label plus the raw text to classify. */
  case class DataFormat(label: Double, text: String)

  def main(args: Array[String]): Unit = {
    val spark = SparkSession
      .builder
      .master("local")
      .appName("NaiveBayesExample")
      .getOrCreate()
    import spark.implicits._

    // Each input line is "label,text". NOTE: Windows paths in Scala string
    // literals need doubled backslashes — the original "D:\htmlf\..." used the
    // invalid escape sequence `\h` and would not compile.
    val gamingPhrase = spark.read.textFile("D:\\htmlf\\ml\\train").map { line =>
      val fields = line.split(",")
      DataFormat(fields(0).toDouble, fields(1))
    }

    // Hold out half the data for evaluation; the fixed seed makes the split reproducible.
    val Array(trainData, testData) = gamingPhrase.randomSplit(Array(0.5, 0.5), seed = 12347890L)

    val tokenizer = new Tokenizer().setInputCol("text").setOutputCol("words")
    val hashTF = new HashingTF().setInputCol("words").setOutputCol("TF")
    val idf = new IDF().setInputCol("TF").setOutputCol("features")

    /**
     * Training.
     * Fit the TF-IDF stages and the classifier on the TRAINING split only.
     * (The original fit everything on the full dataset, so the held-out test
     * rows leaked into the model.)
     */
    val trainWords = tokenizer.transform(trainData)
    val trainTf = hashTF.transform(trainWords)
    val idfModel = idf.fit(trainTf)
    val trainFeatures = idfModel.transform(trainTf)

    // spark.ml NaiveBayes reads the "label"/"features" columns directly;
    // the round-trip through LabeledPoint + Vectors.dense (which also
    // densified the sparse TF-IDF vectors) is unnecessary.
    val model = new NaiveBayes().fit(trainFeatures.select($"label", $"features"))

    /**
     * Evaluation.
     * Reuse the IDF model fitted on the training data. Re-fitting IDF on the
     * test set (as the original did with `idf.fit(tstf)`) computes document
     * frequencies from test data and skews the reported accuracy.
     */
    val testWords = tokenizer.transform(testData)
    val testTf = hashTF.transform(testWords)
    val testFeatures = idfModel.transform(testTf)

    val predictions = model.transform(testFeatures)
    predictions.show()

    // Select (prediction, true label) and compute test accuracy:
    // the fraction of rows where prediction equals the true label.
    val evaluator = new MulticlassClassificationEvaluator()
      .setLabelCol("label")
      .setPredictionCol("prediction")
      .setMetricName("accuracy")
    val accuracy = evaluator.evaluate(predictions)
    println("Test set accuracy = " + accuracy)

    // model.write.overwrite().save("hdfs://10.0.0.88:9000/model/domain_analysis_model")
    model.write.overwrite().save("D:\\anyao\\htmlf\\ml\\model")

    // Release the local Spark context before the JVM exits.
    spark.stop()
  }
}
Naive Bayes modeling with Spark ML
Latest recommended article published 2024-04-01 09:51:49