import org.apache.spark.ml.feature.{Word2Vec, Word2VecModel}
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.sql.{DataFrame, SaveMode, SparkSession}
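/**
 * 从句子训练得到Embedding
 * Trains embeddings from sentences with Spark ML Word2Vec, then saves the model
 * together with the resulting text (sentence) embeddings and word embeddings.
 */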
object TextEmbedding {
/**
 * Embedding size constant. NOTE(review): its use is outside this view; presumably
 * it is the Word2Vec vector size used when training — confirm against `trianByWord2Vec`.
 */
val embeddingSize: Int = 3
/**
 * Entry point: trains a Word2Vec model on the sample text returned by `loadText`,
 * saves the model and the text/word embeddings, then loads the saved model back.
 *
 * Positional arguments (all three are required):
 *   - args(0): location the model is saved to and later loaded from
 *   - args(1): destination passed to `saveTextEmb` for the text embeddings
 *   - args(2): destination passed to `saveWordEmb` for the word embeddings
 *
 * @param args command-line arguments, as described above
 * @throws IllegalArgumentException if fewer than three arguments are supplied
 */
def main(args: Array[String]): Unit = {
  // Validate before starting Spark, so a missing argument fails fast instead of
  // after training (and after any outputs have already been partially written).
  require(
    args.length >= 3,
    s"Usage: TextEmbedding <modelPath> <textEmbPath> <wordEmbPath> (got ${args.length} argument(s))"
  )
  val modelPath = args(0)
  val textEmbPath = args(1)
  val wordEmbPath = args(2)

  val spark = SparkSession.builder()
    .appName("TextEmbedding")
    .master("local[2]")
    .getOrCreate()
  try {
    val df = loadText(spark)
    val model = trianByWord2Vec(spark, df)
    saveModel(spark, model, modelPath)
    saveTextEmb(spark, model, df, textEmbPath)
    saveWordEmb(spark, model, wordEmbPath)
    // The loaded model is discarded; this only exercises the load path
    // (presumably a round-trip sanity check of saveModel — confirm).
    loadModel(spark, modelPath)
  } finally {
    // Release the local Spark session even if any step above throws.
    spark.stop()
  }
}
// 示例 (example): builds a small hard-coded sample DataFrame instead of reading real input.
def loadText(spark: SparkSession): DataFrame = {
val df = spark.createDataFrame(Seq(
(0, Array("Hi",