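1. Source layout
The user-facing API for Spark's LDA model lives in the ml package, in org/apache/spark/ml/clustering/LDA.scala.
The underlying LDA implementation is under apache/spark/tree/branch-2.1/mllib/src/main/scala/org/apache/spark/mllib/clustering. Its source files are organized as follows: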
2. Training and usage
Let's start with a simple usage example:
import org.apache.spark.ml.clustering.LDA
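// Loads data (assumes an active SparkSession named `spark`, as in spark-shell).
val dataset = spark.read.format("libsvm")
  .load("data/mllib/sample_lda_libsvm_data.txt")
// Trains an LDA model.
// Set the number of topics k and the maximum number of iterations.
val lda = new LDA().setK(10).setMaxIter(10)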
val model = lda.fit(dataset)
val ll = model.logLikelihood(dataset)
val lp = model.logPerplexity(dataset)
println(s"The lower bound on the log likelihood of the entire corpus: $ll")
println(s"The upper bound on perplexity: $lp")
// Describe topics.
val topics = model.describeTopics(3)
println("The topics described by their top-weighted terms:")
topics.show(false)
// Shows the result.
val transformed = model.transform(dataset)
transformed.show(false)
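The table returned by describeTopics identifies terms only by index (columns topic, termIndices, termWeights), and the libsvm sample file carries no vocabulary. The sketch below shows one way to pair those indices with words in plain Scala. TopicRow, TopicLabels and the vocab vector are hypothetical names introduced for illustration; they are not part of Spark.

```scala
// Hypothetical container mirroring one row of the describeTopics output.
final case class TopicRow(topic: Int, termIndices: Seq[Int], termWeights: Seq[Double])

object TopicLabels {
  /** Pairs each term index with its word and weight, keeping the original order. */
  def label(row: TopicRow, vocab: IndexedSeq[String]): Seq[(String, Double)] =
    row.termIndices.zip(row.termWeights).map { case (idx, weight) =>
      (vocab(idx), weight)
    }
}

object TopicLabelsDemo extends App {
  // Assumed vocabulary: position i holds the word for term index i.
  val vocab = Vector("spark", "topic", "model", "data")
  val row = TopicRow(topic = 0, termIndices = Seq(2, 0, 3), termWeights = Seq(0.4, 0.3, 0.1))
  TopicLabels.label(row, vocab).foreach { case (word, w) => println(f"$word%-6s $w%.2f") }
}
```

With a real model, the rows would presumably come from topics.collect(), reading termIndices and termWeights out of each Row; the vocabulary has to come from whatever step produced the feature vectors (for example a CountVectorizer), which this example does not cover.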
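3. Source code analysis
The user-facing API is in the ml package: LDA.scala defines the LDA class. It extends the ml Estimator class with LDAModel as its model type, and that LDAModel is backed by the model from mllib.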
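Before reading the actual source, the wrapper relationship can be sketched in plain Scala without Spark. Every name here (OldTrainer, OldModel, NewEstimator, NewModel) is a hypothetical stand-in, and the training step is a stub, so this illustrates only the shape of the delegation, not Spark's code.

```scala
// Hypothetical stand-ins; none of these names are Spark classes.
object DelegationSketch {
  // "Old" (mllib-style) side: owns the training logic.
  final case class OldModel(k: Int, vocabSize: Int)

  final class OldTrainer(k: Int, maxIterations: Int) {
    // Stub: records k and the vocabulary size instead of running real inference.
    def run(docs: Seq[Array[Double]]): OldModel =
      OldModel(k, docs.headOption.map(_.length).getOrElse(0))
  }

  // "New" (ml-style) side: holds parameters and wraps the old model.
  final case class NewModel(uid: String, oldModel: OldModel)

  final class NewEstimator(val uid: String, k: Int, maxIter: Int) {
    def fit(docs: Seq[Array[Double]]): NewModel = {
      val old = new OldTrainer(k, maxIter) // copy parameters to the old API
      NewModel(uid, old.run(docs))         // wrap the trained old model
    }
  }
}
```

The point of the pattern is that the new estimator carries parameters and a uid, while the algorithm itself stays in one place on the old side. The real class definition follows.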
class LDA @Since("1.6.0") (
    @Since("1.6.0") override val uid: String)
  extends Estimator[LDAModel] with LDAParams with DefaultParamsWritable {
  // Training method
  @Since("2.0.0")
  override def fit(dataset: Dataset[_]): LDAModel = {
    transformSchema(dataset.schema, logging = true)
    // Delegate to the mllib implementation
val oldLDA &