package com.xtf.demo.mllib
import org.apache.log4j.{Level, Logger}
import org.apache.spark.ml.clustering.{DistributedLDAModel, LDA, LocalLDAModel}
import org.apache.spark.ml.feature.HashingTF
import org.apache.spark.ml.linalg.DenseVector
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.DoubleType
object ContentLDA {
Logger.getLogger("org").setLevel(Level.ERROR)
def main(args: Array[String]): Unit = {
val spark = SparkSession.builder().appName("test")
.master("local[*]").getOrCreate()
val dataDF = spark.createDataFrame(Seq(
("aaa", Array("a", "b", "c", "d", "e", "spark", "spark")),
("bbb", Array("b", "d")),
("ccc", Array("spark", "f", "g", "h")),
("ddd", Array("hadoop", "mapreduce"))
)).toDF("id", "words")
// HashingTF maps each word array to a sparse term-frequency vector (default 2^18 = 262144 features)
val hashingTF = new HashingTF().setInputCol("words").setOutputCol("rawFeatures")
val hashingDF = hashingTF.transform(dataDF)
hashingDF.show(false)
// Generate an auto-incrementing index column. Method one: row_number over a Window
// val indexDF = dataDF.withColumn("index",
// row_number().over(Window.orderBy("words")).cast("double"))
// indexDF.show(false)
// Method two: zipWithIndex on the underlying RDD
val indexRDD = hashingDF.rdd.zipWithIndex()
val mergeRdd = indexRDD.map { case(row, lon) =>
Row.merge(row, Row(lon.toDouble))
}
val newSchema = hashingDF.schema.add("index", DoubleType)
val mergeDF = spark.createDataFrame(mergeRdd, newSchema)
mergeDF.show(false)
// optimizer: the inference algorithm; "em" and "online" are currently supported
// A distributed model can be converted to a local model, but not the other way around
// EMLDAOptimizer produces a DistributedLDAModel, which stores not only the inferred topics
// but also the full training corpus and the topic distribution of every training document
// OnlineLDAOptimizer produces a LocalLDAModel, which stores only the inferred topics
// (see the sketch after the console output below)
val lda = new LDA()
.setK(2)
.setMaxIter(10)
.setFeaturesCol("rawFeatures")
.setOptimizer("em")
val ldaModel = lda.fit(mergeDF)
// Persist the model and load it back; with the "em" optimizer the saved model is a DistributedLDAModel
ldaModel.write.overwrite().save("./ldaModel")
val distributedLDAModel = DistributedLDAModel.load("./ldaModel")
// val localLDAModel = LocalLDAModel.load("./ldaModel")
val df = distributedLDAModel.transform(mergeDF)
df.show(false)
// For each document keep the single most probable topic, but only if its weight exceeds 0.5,
// and format the result as id|topicIndex|weight
val sedf = df.select("id", "topicDistribution").rdd.map{
x => (x.getAs[String]("id"), x.getAs[DenseVector]("topicDistribution"))
}.map{case(id: String, vector: DenseVector) =>
val arr = vector.toArray.zipWithIndex.sortBy(-_._1).take(1).filter(_._1 > 0.5)
.map(x => id + "|" + x._2 + "|" + x._1.formatted("%.3f"))
arr.mkString("#")
}
println(sedf.collect().toBuffer)
spark.stop()
}
}
Console output:
+---+-----------------------------+--------------------------------------------------------------------------+
|id |words |rawFeatures |
+---+-----------------------------+--------------------------------------------------------------------------+
|aaa|[a, b, c, d, e, spark, spark]|(262144,[17222,27526,28698,30913,227410,234657],[1.0,1.0,1.0,1.0,1.0,2.0])|
|bbb|[b, d] |(262144,[27526,30913],[1.0,1.0]) |
|ccc|[spark, f, g, h] |(262144,[15554,24152,51505,234657],[1.0,1.0,1.0,1.0]) |
|ddd|[hadoop, mapreduce] |(262144,[42633,155117],[1.0,1.0]) |
+---+-----------------------------+--------------------------------------------------------------------------+
+---+-----------------------------+--------------------------------------------------------------------------+-----+
|id |words |rawFeatures |index|
+---+-----------------------------+--------------------------------------------------------------------------+-----+
|aaa|[a, b, c, d, e, spark, spark]|(262144,[17222,27526,28698,30913,227410,234657],[1.0,1.0,1.0,1.0,1.0,2.0])|0.0 |
|bbb|[b, d] |(262144,[27526,30913],[1.0,1.0]) |1.0 |
|ccc|[spark, f, g, h] |(262144,[15554,24152,51505,234657],[1.0,1.0,1.0,1.0]) |2.0 |
|ddd|[hadoop, mapreduce] |(262144,[42633,155117],[1.0,1.0]) |3.0 |
+---+-----------------------------+--------------------------------------------------------------------------+-----+
+---+-----------------------------+--------------------------------------------------------------------------+-----+----------------------------------------+
|id |words |rawFeatures |index|topicDistribution |
+---+-----------------------------+--------------------------------------------------------------------------+-----+----------------------------------------+
|aaa|[a, b, c, d, e, spark, spark]|(262144,[17222,27526,28698,30913,227410,234657],[1.0,1.0,1.0,1.0,1.0,2.0])|0.0 |[0.5066506944443651,0.49334930555563494]|
|bbb|[b, d] |(262144,[27526,30913],[1.0,1.0]) |1.0 |[0.5024641283101848,0.4975358716898151] |
|ccc|[spark, f, g, h] |(262144,[15554,24152,51505,234657],[1.0,1.0,1.0,1.0]) |2.0 |[0.5083923945368911,0.49160760546310894]|
|ddd|[hadoop, mapreduce] |(262144,[42633,155117],[1.0,1.0]) |3.0 |[0.49363189084987436,0.5063681091501256]|
+---+-----------------------------+--------------------------------------------------------------------------+-----+----------------------------------------+
ArrayBuffer(aaa|0|0.507, bbb|0|0.502, ccc|0|0.508, ddd|1|0.506)
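As the comments in the listing note, the DistributedLDAModel produced by the "em" optimizer can be converted to a LocalLDAModel (the reverse is not possible), and the "online" optimizer trains a local model directly. Below is a minimal sketch of both paths, reusing the spark session, mergeDF and distributedLDAModel from the listing above; the names onlineModel and localFromDistributed, and the choice of 3 terms per topic in describeTopics, are introduced here purely for illustration.

// Convert the distributed model to a local one; the reverse conversion does not exist
val localFromDistributed: LocalLDAModel = distributedLDAModel.toLocal

// The "online" optimizer trains a LocalLDAModel directly (it keeps only the inferred topics)
val onlineModel = new LDA()
  .setK(2)
  .setMaxIter(10)
  .setFeaturesCol("rawFeatures")
  .setOptimizer("online")
  .fit(mergeDF)

// describeTopics lists the highest-weighted terms per topic; termIndices are the
// HashingTF feature indices seen in the rawFeatures column
distributedLDAModel.describeTopics(3).show(false)
onlineModel.describeTopics(3).show(false)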