scala做embedding的average操作
使用 breeze.linalg 对 embedding 向量进行处理。
breeze.linalg 库支持对矩阵和向量的多种运算,例如普通的加减乘除、点乘、叉乘等。
import breeze.linalg.DenseVector
import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.sql.{DataFrame, SaveMode, SparkSession}
import org.apache.spark.storage.StorageLevel
/**
 * Averages item embeddings over each key's history list, using breeze.linalg
 * for the vector arithmetic.
 */
object EmbeddingAvg {
  /** Fallback embedding dimension, used only when no vector in `emb` is available to infer it from. */
  val embeddingSize: Int = 3

  /**
   * Computes, for every row of `history`, the element-wise mean of the embeddings
   * referenced by that row.
   *
   * Ids that are not present in `emb` contribute a zero vector and still count
   * towards the denominator of the mean. A row whose "list" value is null or
   * yields no ids at all produces a zero vector.
   *
   * Side effects: `emb` is fully collected to the driver, the result is persisted
   * with `MEMORY_AND_DISK_SER`, and `count()` / `show(false)` are run on it before
   * it is returned.
   *
   * @param spark   active session; supplies the encoders and the SparkContext used for broadcasting
   * @param emb     embedding table read through a string column "id" and an ML `Vector` column "vector"
   * @param history table read through a string column "key" and a string column "list"
   *                holding comma-separated ids
   * @return a DataFrame with columns "key" and "emb", where "emb" is the averaged
   *         vector rendered as a comma-separated string
   */
  def avgAction(spark: SparkSession, emb: DataFrame, history: DataFrame): DataFrame = {
    import spark.implicits._

    val embMap: Map[String, Vector] = emb.collect()
      .map(row => row.getAs[String]("id") -> row.getAs[Vector]("vector"))
      .toMap

    // Size the zero vector after the real embeddings: a fixed size would make the
    // Breeze addition below fail whenever the data has a different dimension.
    val defaultEmbSize = embMap.valuesIterator
      .collectFirst { case v if v != null => v.size }
      .getOrElse(embeddingSize)
    println(s"defaultEmbSize:$defaultEmbSize")

    // Broadcast the lookup table instead of capturing it in the task closure,
    // so it is not serialized again with every task.
    val embBroadcast = spark.sparkContext.broadcast(embMap)

    val result = history.map { row =>
      val key = row.getAs[String]("key")
      val lookup = embBroadcast.value

      // Option(...) guards against a null "list" value. String.split drops trailing
      // empty strings, so an input such as "," also yields an empty array here.
      val ids = Option(row.getAs[String]("list")).fold(Array.empty[String])(_.split(","))

      val vectors = ids.map { id =>
        new DenseVector[Double](lookup.getOrElse(id, Vectors.zeros(defaultEmbSize)).toArray)
      }

      // `+` allocates a new vector each time, so the arrays backing the looked-up
      // embeddings are never mutated.
      val avg = vectors.reduceOption(_ + _) match {
        case Some(sum) => sum *:* (1.0 / vectors.length)
        case None      => DenseVector.zeros[Double](defaultEmbSize)
      }

      (key, avg.toArray.mkString(","))
    }.toDF("key", "emb")
      .persist(StorageLevel.MEMORY_AND_DISK_SER)

    println(s"result count:${result.count()}")
    result.show(false)
    result
  }
}
breeze.linalg 库的更多用法可参考:https://www.cnblogs.com/itboys/p/10594039.html