一个简单的spark贝叶斯分类程序

在笔记本跑了一个简单的贝叶斯分类示例,工程级的代码原理类似,只不过有些细节需要修改。

主要代码如下:

import org.apache.spark.sql.{DataFrame, Row, SparkSession}
import org.apache.spark.ml.feature.{HashingTF, _}
import org.apache.hadoop.fs.Path
import org.apache.spark.ml.classification.NaiveBayes
import org.apache.spark.ml.feature._

object bayes {
  def main(args: Array[String]) {
    val spark = SparkSession
      .builder
      .appName("bayes")
      .getOrCreate()

    import spark.implicits._

    val sentenceDataFrame = spark.createDataFrame(Seq(    //比较简单的样本数据 0分类 水果; 1分类 粮食
      (0,"水果","苹果 橘子 香蕉"),
      (1, "粮食","大米 小米 土豆")
    )).toDF("label","category", "text")

    val tokenizer = new Tokenizer().setInputCol("text").setOutputCol("words")
    var wordData = tokenizer.transform(sentenceDataFrame)

    val stopwordFile: String = "/applications/stopWords"        //引入停用词

    val customizedStopWords: Array[String] = if (stopwordFile.isEmpty()) {
      Array.empty[String]
    } else {
      val stopWordText = spark.read.text(stopwordFile).as[String].collect()
      stopWordText.flatMap(_.stripMargin.split("\\s+"))
    }

    val stopWordsRemover = new StopWordsRemover()
      .setInputCol("words")
      .setOutputCol("token")
    stopWordsRemover.setStopWords(stopWordsRemover.getStopWords ++ customizedStopWords)
    var wordDataWithOutStopWord = stopWordsRemover.transform(wordData)

    var hashingTF = new HashingTF()
      .setInputCol("token").setOutputCol("tf")
    val tf= hashingTF.transform(wordDataWithOutStopWord)
    tf.cache()
    tf.show(false)

    val idf=new IDF().setInputCol("tf").setOutputCol("features").fit(tf)  //根据以上数据训练的idf模型,实际需要根据大量数据训练
    val tfidf =idf.transform(tf)
    tfidf.show(false)


    val naiveBayesModel = new NaiveBayes()  //创建贝叶斯模型,用上面数据训练
      .setSmoothing(1)
      .fit(tfidf)

    val training = spark.createDataFrame(List(  //待预测的测试数据
      (0, "大米")
    )).toDF("id", "text")


    var tokenfeature = tokenizer.transform(training)
    wordDataWithOutStopWord = stopWordsRemover.transform(tokenfeature)

    var trainRescaledData = hashingTF.transform(wordDataWithOutStopWord)
    val tfidf1 = idf.transform(trainRescaledData)
    val predictions = naiveBayesModel
      .transform(tfidf1)
    predictions.printSchema()
    val predict = predictions.first().getAs[Double]("prediction")   //预测结果 输出label 为1 粮食分类

    println("predict aaaaa:")
    println(predict)


    spark.stop()

  }

}
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值