Topic analysis with the LDA algorithm in Spark MLlib (Scala code)

package com.xtf.demo.mllib

import org.apache.log4j.{Level, Logger}
import org.apache.spark.ml.clustering.{DistributedLDAModel, LDA, LocalLDAModel}
import org.apache.spark.ml.feature.HashingTF
import org.apache.spark.ml.linalg.DenseVector
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.DoubleType

object ContentLDA {
    Logger.getLogger("org").setLevel(Level.ERROR)

    def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder().appName("test")
            .master("local[*]").getOrCreate()
        val dataDF = spark.createDataFrame(Seq(
            ("aaa", Array("a", "b", "c", "d", "e", "spark", "spark")),
            ("bbb", Array("b", "d")),
            ("ccc", Array("spark", "f", "g", "h")),
            ("ddd", Array("hadoop", "mapreduce"))
        )).toDF("id", "words")
        // HashingTF hashes each term into a fixed-size feature space
        // (default 2^18 = 262144 buckets) and counts term frequencies.
        val hashingTF = new HashingTF().setInputCol("words").setOutputCol("rawFeatures")
        val hashingDF = hashingTF.transform(dataDF)
        hashingDF.show(false)

        // Generate an auto-incrementing index column, method one: row_number over a Window
        //        val indexDF = dataDF.withColumn("index",
        //            row_number().over(Window.orderBy("words")).cast("double"))
        //        indexDF.show(false)

        // Method two: zip each row with its position via the underlying RDD.
        // (Note: the ml LDA API below only reads the features column, so this
        // index column is not actually consumed; it is kept for reference.)
        val indexRDD = hashingDF.rdd.zipWithIndex()
        val mergeRdd = indexRDD.map { case (row, idx) =>
            Row.merge(row, Row(idx.toDouble))
        }
        val newSchema = hashingDF.schema.add("index", DoubleType)
        val mergeDF = spark.createDataFrame(mergeRdd, newSchema)
        mergeDF.show(false)

        // Optimizer: the inference algorithm; currently "em" and "online" are supported.
        // A distributed model can be converted to a local model, but not the reverse.
        // EMLDAOptimizer produces a DistributedLDAModel, which stores not only the
        // inferred topics but also the full training corpus and the topic
        // distribution of every document in it.
        // OnlineLDAOptimizer produces a LocalLDAModel, which stores only the
        // inferred topics.
        val lda = new LDA()
                .setK(2)
                .setMaxIter(10)
                .setFeaturesCol("rawFeatures")
                .setOptimizer("em")
        val ldaModel = lda.fit(mergeDF)

        ldaModel.write.overwrite().save("./ldaModel")
        val distributedLDAModel = DistributedLDAModel.load("./ldaModel")
//        val localLDAModel = LocalLDAModel.load("./ldaModel")
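        // With the "em" optimizer the saved model is a DistributedLDAModel; if only
        // the inferred topics are needed, it can be converted with
        // distributedLDAModel.toLocal (the reverse conversion is not possible).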

        val df = distributedLDAModel.transform(mergeDF)
        df.show(false)

        // Keep each document's dominant topic, but only when its probability
        // exceeds 0.5; format each hit as id|topicIndex|probability.
        val sedf = df.select("id", "topicDistribution").rdd.map {
            x => (x.getAs[String]("id"), x.getAs[DenseVector]("topicDistribution"))
        }.map { case (id: String, vector: DenseVector) =>
            val arr = vector.toArray.zipWithIndex.sortBy(-_._1).take(1).filter(_._1 > 0.5)
                .map(x => id + "|" + x._2 + "|" + x._1.formatted("%.3f"))
            arr.mkString("#")
        }
        println(sedf.collect().toBuffer)

        spark.stop()
    }

}

Console output:

+---+-----------------------------+--------------------------------------------------------------------------+
|id |words                        |rawFeatures                                                               |
+---+-----------------------------+--------------------------------------------------------------------------+
|aaa|[a, b, c, d, e, spark, spark]|(262144,[17222,27526,28698,30913,227410,234657],[1.0,1.0,1.0,1.0,1.0,2.0])|
|bbb|[b, d]                       |(262144,[27526,30913],[1.0,1.0])                                          |
|ccc|[spark, f, g, h]             |(262144,[15554,24152,51505,234657],[1.0,1.0,1.0,1.0])                     |
|ddd|[hadoop, mapreduce]          |(262144,[42633,155117],[1.0,1.0])                                         |
+---+-----------------------------+--------------------------------------------------------------------------+

+---+-----------------------------+--------------------------------------------------------------------------+-----+
|id |words                        |rawFeatures                                                               |index|
+---+-----------------------------+--------------------------------------------------------------------------+-----+
|aaa|[a, b, c, d, e, spark, spark]|(262144,[17222,27526,28698,30913,227410,234657],[1.0,1.0,1.0,1.0,1.0,2.0])|0.0  |
|bbb|[b, d]                       |(262144,[27526,30913],[1.0,1.0])                                          |1.0  |
|ccc|[spark, f, g, h]             |(262144,[15554,24152,51505,234657],[1.0,1.0,1.0,1.0])                     |2.0  |
|ddd|[hadoop, mapreduce]          |(262144,[42633,155117],[1.0,1.0])                                         |3.0  |
+---+-----------------------------+--------------------------------------------------------------------------+-----+

+---+-----------------------------+--------------------------------------------------------------------------+-----+----------------------------------------+
|id |words                        |rawFeatures                                                               |index|topicDistribution                       |
+---+-----------------------------+--------------------------------------------------------------------------+-----+----------------------------------------+
|aaa|[a, b, c, d, e, spark, spark]|(262144,[17222,27526,28698,30913,227410,234657],[1.0,1.0,1.0,1.0,1.0,2.0])|0.0  |[0.5066506944443651,0.49334930555563494]|
|bbb|[b, d]                       |(262144,[27526,30913],[1.0,1.0])                                          |1.0  |[0.5024641283101848,0.4975358716898151] |
|ccc|[spark, f, g, h]             |(262144,[15554,24152,51505,234657],[1.0,1.0,1.0,1.0])                     |2.0  |[0.5083923945368911,0.49160760546310894]|
|ddd|[hadoop, mapreduce]          |(262144,[42633,155117],[1.0,1.0])                                         |3.0  |[0.49363189084987436,0.5063681091501256]|
+---+-----------------------------+--------------------------------------------------------------------------+-----+----------------------------------------+

ArrayBuffer(aaa|0|0.507, bbb|0|0.502, ccc|0|0.508, ddd|1|0.506)
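
With only four tiny documents the topic probabilities barely clear 0.5, and since HashingTF discards the term-to-index mapping, indices such as 234657 cannot be translated back into words. Below is a minimal sketch of a more interpretable variant, assuming it runs inside the same main with the same dataDF (cvModel, countDF, and cvLdaModel are illustrative names, not part of the original code): CountVectorizer keeps a fitted vocabulary, so the output of describeTopics can be mapped back to actual terms.

import org.apache.spark.ml.feature.CountVectorizer

// CountVectorizer builds an explicit vocabulary instead of hashing terms.
val cvModel = new CountVectorizer()
    .setInputCol("words").setOutputCol("rawFeatures")
    .fit(dataDF)
val countDF = cvModel.transform(dataDF)

val cvLdaModel = new LDA().setK(2).setMaxIter(10)
    .setFeaturesCol("rawFeatures").setOptimizer("em")
    .fit(countDF)

// describeTopics returns (topic, termIndices, termWeights);
// map the term indices back through the fitted vocabulary.
val vocab = cvModel.vocabulary
cvLdaModel.describeTopics(3).collect().foreach { row =>
    val topic = row.getAs[Int]("topic")
    val terms = row.getAs[Seq[Int]]("termIndices").map(vocab(_))
    println(s"topic $topic: " + terms.mkString(", "))
}

// Lower perplexity on held-out documents generally indicates a better K.
println("log perplexity: " + cvLdaModel.logPerplexity(countDF))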
