主题分析模型LDA的spark实现

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/ZH519080/article/details/85010327

主体分析模型主要有PLSA(Probabilistic Latent Semantic Anlysis,概率引语义分析)和LDA(Latent Dirichlet Allocation,隐含狄利克雷分布),在此暂时介绍LDA的spark实现。

    * 主题分析模型自动分析每个文档,统计文档内的词语,根据统计的信息来判断当前文档含有
    * 哪些主题,以及每个主题所占的比例格式多少。
    * 可将LDA的主题分析结果进一步用在数据挖掘任务中的,其任务大概有:
    * 1、主题推断。给定一篇新的文档x,利用已有的主题模型训练的结果,计算出文档x所包含的主题,
    * 以及各个主题的比重。
    * 2、文档聚类。主题可看作是聚类中心,而文档可看作是多个聚类中心相关联的数据样本,
    * 利用主题模型做文档聚类,可用于重新组织文档数据集。
    * 3、特征选择。由于主题模型可推断出每个文档在不同主题上的分布,因此这个分布可看作文档的一个新特征,
    * 该特征可用于其他的机器学习模型中。
    * 4、降维。主题模型得到的主题分布特征,可看作将原来的高纬度文档向量投影到低纬度主题空间中。

import org.apache.spark.ml.clustering.{LDA, LDAModel}
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.ml.feature.{CountVectorizer, CountVectorizerModel, HashingTF, Tokenizer}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, SQLContext}

/**
  * author: ZH519080@163.com  朱海川
  * date_time: 2017/10/13 14:31
  * illustration: 
  */
class LDAThemeAnalysis {

  private def data(sparkContext: SparkContext): DataFrame ={

    val sqlContext = new SQLContext(sparkContext)
    import sqlContext.implicits._
    val data = sparkContext.parallelize(Seq(
      (1,Array("祖国","万岁")),
      (2,Array("中华人民共和国","雄起")),
      (3,Array("万岁","中国")),
      (4,Array("祖国","雄起")),
      (5,Array("中华","雄起")),
      (6,Array("雄起")))).map{x =>
      (x._1,x._2)
    }.toDF("id","context")
    data
  }

  private def valueCompute(dataFrame: DataFrame): Unit ={

    val cv = new CountVectorizer()
      .setInputCol("context").setOutputCol("features")
    val cvmodel: CountVectorizerModel = cv.fit(dataFrame)
    val cvResult: DataFrame = cvmodel.transform(dataFrame)
    /**获得转成向量时词表*/
    val vocabulary = cvmodel.vocabulary

    /**setK:主题(聚类中心)个数
      * setMaxIter:最大迭代次数
      * setOptimizer:优化计算方法,支持”em“和”online“
      * setDocConcentration:文档-主题分布的堆成先验Dirichlet参数,值越大,分布越平滑,值>1.0
      * setTopicConcentration:文档-词语分布的先验Dirichlet参数,值越大,分布越平滑,值>1.0
      *setCheckpointInterval:checkpoint的检查间隔
      * */
    val lda = new LDA()
      .setK(3)
      .setMaxIter(20)
        .setOptimizer("em")
      .setDocConcentration(2.2)
      .setTopicConcentration(1.5)

    val ldamodel: LDAModel = lda.fit(cvResult)
    /**可能度*/
    ldamodel.logLikelihood(cvResult)
    /**困惑度,困惑度越小,模型训练越好*/
    ldamodel.logPerplexity(cvResult)

    val ladmodel: DataFrame = ldamodel.transform(cvResult)
    ladmodel.foreach(println)

  }

}

object LDAThemeAnalysis {
  def main(args: Array[String]): Unit = {

    val sparkConf = new SparkConf().setAppName("lda theme analysis")
    val sparkContext = SparkContext.getOrCreate(sparkConf)

    val lDAThemeAnalysis = new LDAThemeAnalysis
    val dataFrame = lDAThemeAnalysis.data(sparkContext)
    lDAThemeAnalysis.valueCompute(dataFrame)
    println("。。。。。。。。。 我很高兴啊 。。。。。。。。。")
  }
}

执行结果:

 

展开阅读全文

没有更多推荐了,返回首页