scala做embedding的average操作

15 篇文章 0 订阅
13 篇文章 0 订阅

scala做embedding的average操作

使用 breeze.linalg 来对embedding向量处理

breeze.linalg 库可以对矩阵向量做很多操作,普通的加减乘除,点乘叉乘,都能支持

import breeze.linalg.DenseVector
import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.sql.{DataFrame, SaveMode, SparkSession}
import org.apache.spark.storage.StorageLevel

/**
  * scala做embedding的average操作
  */
object EmbeddingAvg {
  val embeddingSize = 3

  def avgAction(spark: SparkSession, emb: DataFrame, history: DataFrame): DataFrame = {
    import spark.implicits._
    val defaultEmbSize = embeddingSize
    println("defaultEmbSize:" + defaultEmbSize)
    val embMap = emb.collect()
      .map(row => {
        val key = row.getAs[String]("id")
        val value = row.getAs[Vector]("vector")
        (key, value)
      }).toMap

    val result = history.map(row => {
      val key = row.getAs[String]("key")
      val list = row.getAs[String]("list")
        .split(",")
        .map(x => {
          if (embMap.contains(x)) {
            embMap(x)
          } else {
            Vectors.zeros(defaultEmbSize)
          }
        })

      val head = list.head
      var res = new DenseVector[Double](head.toArray)
      list.tail.foreach(v => {
        val ve = new DenseVector[Double](v.toArray)
        res = res + ve
      })

	// avg 操作
      res = res *:* (1.0 / list.length)

      (key, res.toArray.mkString(","))
    }).toDF("key", "emb")
      .persist(StorageLevel.MEMORY_AND_DISK_SER)

    println("result count:"  + result.count())
    result.show(false)
    result
  }

}

breeze.linalg 库可参考网上 https://www.cnblogs.com/itboys/p/10594039.html

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值