I ran a simple Naive Bayes classification example in a notebook. Production-grade code follows the same principles; only a few details need to change.
The main code is as follows:
import org.apache.spark.ml.classification.NaiveBayes
import org.apache.spark.ml.feature.{HashingTF, IDF, StopWordsRemover, Tokenizer}
import org.apache.spark.sql.SparkSession

object bayes {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession
      .builder
      .appName("bayes")
      .getOrCreate()
    import spark.implicits._

    // Simple sample data: label 0 = fruit (水果), label 1 = grain (粮食)
    val sentenceDataFrame = spark.createDataFrame(Seq(
      (0, "水果", "苹果 橘子 香蕉"),
      (1, "粮食", "大米 小米 土豆")
    )).toDF("label", "category", "text")

    // Split the text column into individual words
    val tokenizer = new Tokenizer().setInputCol("text").setOutputCol("words")
    val wordData = tokenizer.transform(sentenceDataFrame)

    // Load custom stop words from a file and merge them with the defaults
    val stopwordFile: String = "/applications/stopWords"
    val customizedStopWords: Array[String] = if (stopwordFile.isEmpty) {
      Array.empty[String]
    } else {
      val stopWordText = spark.read.text(stopwordFile).as[String].collect()
      stopWordText.flatMap(_.stripMargin.split("\\s+"))
    }
    val stopWordsRemover = new StopWordsRemover()
      .setInputCol("words")
      .setOutputCol("token")
    stopWordsRemover.setStopWords(stopWordsRemover.getStopWords ++ customizedStopWords)
    val wordDataWithoutStopWords = stopWordsRemover.transform(wordData)

    // Term frequency via feature hashing
    val hashingTF = new HashingTF()
      .setInputCol("token")
      .setOutputCol("tf")
    val tf = hashingTF.transform(wordDataWithoutStopWords)
    tf.cache()
    tf.show(false)

    // IDF model fitted on the tiny sample above; in practice it should be trained on a large corpus
    val idf = new IDF().setInputCol("tf").setOutputCol("features").fit(tf)
    val tfidf = idf.transform(tf)
    tfidf.show(false)

    // Train the Naive Bayes model on the TF-IDF features
    val naiveBayesModel = new NaiveBayes()
      .setSmoothing(1.0)
      .fit(tfidf)

    // Test data to be classified, run through the same feature steps
    val testData = spark.createDataFrame(List(
      (0, "大米")
    )).toDF("id", "text")
    val testTokens = tokenizer.transform(testData)
    val testWithoutStopWords = stopWordsRemover.transform(testTokens)
    val testTf = hashingTF.transform(testWithoutStopWords)
    val testTfidf = idf.transform(testTf)

    val predictions = naiveBayesModel.transform(testTfidf)
    predictions.printSchema()
    // Expected prediction: label 1, i.e. the grain (粮食) class
    val predict = predictions.first().getAs[Double]("prediction")
    println(s"prediction: $predict")

    spark.stop()
  }
}
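For a more production-oriented setup, the same stages can be chained into a single spark.ml Pipeline, so the fitted preprocessing steps and the classifier are saved and reloaded together. The sketch below is only an illustration under a few assumptions: it reuses the sentenceDataFrame and testData DataFrames from the example above, keeps the same column names, and uses /tmp/bayes-pipeline purely as a placeholder save path.

import org.apache.spark.ml.{Pipeline, PipelineModel}
import org.apache.spark.ml.classification.NaiveBayes
import org.apache.spark.ml.feature.{HashingTF, IDF, StopWordsRemover, Tokenizer}

// A minimal sketch: the same stages wired into one Pipeline.
// Column names ("text", "label") match the example above; the save path is a placeholder.
val pipeline = new Pipeline().setStages(Array(
  new Tokenizer().setInputCol("text").setOutputCol("words"),
  new StopWordsRemover().setInputCol("words").setOutputCol("token"),
  new HashingTF().setInputCol("token").setOutputCol("tf"),
  new IDF().setInputCol("tf").setOutputCol("features"),
  new NaiveBayes().setSmoothing(1.0)
))

// Fitting the pipeline trains the IDF and NaiveBayes stages in one pass
val pipelineModel = pipeline.fit(sentenceDataFrame)

// Persist and reload the whole fitted pipeline, then score new data
pipelineModel.write.overwrite().save("/tmp/bayes-pipeline")
val reloaded = PipelineModel.load("/tmp/bayes-pipeline")
reloaded.transform(testData).select("text", "prediction").show(false)

Because the reloaded PipelineModel carries the fitted stop-word list, hashing, IDF weights, and the trained classifier, the scoring side only needs the raw text column, which is usually the detail that changes when moving from a notebook example to an engineering job.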