HitRecall的scala实现

HitRecall的scala实现

topK推荐中常用的评估指标之一是HitRecall(HR@K),其定义如下:
$$\mathrm{HR@K} = \frac{\mathrm{NumberOfHits@K}}{|\mathrm{GT}|}$$
其中分母 $|\mathrm{GT}|$ 是测试集中所有真实正样本(ground truth)的数量,分子是所有用户的topK推荐结果中命中测试集的物料数之和。

事先将用户向量和物料向量导出为文本文件,两者格式一致:每行为 `id||向量`,向量各维用逗号分隔,示例如下:

userid1||0.03388985991477966,-0.05629376694560051,-0.28580108284950256,-0.03301679715514183,-0.00348331220448017,0.23206464946269989,-0.07762476801872253,0.12137320637702942
userid2||0.02444665506482124,0.17300844192504883,-0.17057773470878601,-0.12084504216909409,-0.04795640334486961,0.18982534110546112,-0.10042955726385117,0.10137677937746048

加载用户向量方式如下

  /**
   * Loads user embeddings from text files where each line is `<userid>||<v1>,<v2>,...`.
   *
   * @param ss   active Spark session
   * @param path input path readable by `ss.read.text`
   * @return one `(userid, vector)` pair per input line, Kryo-encoded
   */
  def loadUserEmb(ss: SparkSession, path: String): Dataset[(String, DenseVector[Double])] = {
    // Spark has no built-in encoder for breeze vectors, so the pairs are Kryo-serialized.
    val userVecEncoder = org.apache.spark.sql.Encoders.kryo[(String, breeze.linalg.DenseVector[Double])]
    ss.read.text(path).map { textRow =>
      val fields = textRow.getString(0).split("\\|\\|")
      val embedding = DenseVector(fields(1).split(",").map(_.toDouble))
      (fields(0), embedding)
    }(userVecEncoder)
  }

物料向量数量相对较少,可以直接collect到driver端后broadcast出去。物料向量文件格式与用户向量相同:

itemid1||0.03388985991477966,-0.05629376694560051,-0.28580108284950256,-0.03301679715514183,-0.00348331220448017,0.23206464946269989,-0.07762476801872253,0.12137320637702942
itemid2||0.02444665506482124,0.17300844192504883,-0.17057773470878601,-0.12084504216909409,-0.04795640334486961,0.18982534110546112,-0.10042955726385117,0.10137677937746048

  /**
   * Loads item embeddings (lines of `<itemid>||<v1>,<v2>,...`) and broadcasts them as a map.
   *
   * The whole item set is collected to the driver, so it must fit in driver memory.
   * If an item id appears more than once, the last collected vector wins (`toMap` semantics).
   *
   * @param ss   active Spark session
   * @param path input path readable by `ss.read.text`
   * @return broadcast of itemid -> embedding vector
   */
  def loadItemEmbBc(ss: SparkSession, path: String): Broadcast[Map[String, DenseVector[Double]]] = {
    // Spark has no built-in encoder for breeze vectors, so the pairs are Kryo-serialized.
    val itemVecEncoder = org.apache.spark.sql.Encoders.kryo[(String, breeze.linalg.DenseVector[Double])]
    val items = ss.read.text(path).map { textRow =>
      val fields = textRow.getString(0).split("\\|\\|")
      (fields(0), DenseVector(fields(1).split(",").map(_.toDouble)))
    }(itemVecEncoder).collect()

    // Counting the driver-side array avoids a second Spark job, and with it any need to
    // cache (and later release) the Dataset.
    println("item emb count:" + items.length)

    ss.sparkContext.broadcast(items.toMap)
  }

加载样本,即真实测试集。样本每行格式为 `物料id_用户id|label|特征`,只保留label为1的正样本,并按用户聚合成物料列表:

item1_user5|0|86755830550:20:0.298238;116052872301:27:1.0;1251978090330:291:1.0986122886681098
item2_user2|1|86755830550:20:0.298238;116052872301:27:1.0;1251978090330:291:1.0986122886681098
item1_user2|1|86755830550:20:0.298238;116052872301:27:1.0;1251978090330:291:1.0986122886681098
def loadSample(ss: SparkSession, path: String): DataFrame = {
    import ss.implicits._
    // NOTE(review): listPath is a helper not shown here; presumably it expands `path`
    // into the list of input files — confirm.
    val pathList = listPath(ss, path)
    // Lines look like "<itemid>_<userid>|<label>|<features>". Only positive samples
    // (label "1") form the ground truth. Result columns: userid, item_list (comma-joined ids).
    ss.read.option("sep", "|").csv(pathList: _*)
      .toDF("info", "label", "sample")
      .filter(col("label") === "1")
      .map(row => {
        val info = row.getAs[String]("info").split("_")
        // Item id is the first "_"-token and user id the last one.
        // NOTE(review): ids that themselves contain "_" would be split incorrectly.
        val itemId = info.head
        val userId = info.last
        (userId, itemId)
      }).toDF("userid", "itemid")
      .groupBy("userid")
      .agg(collect_list("itemid").as("item_list"))
      .select(col("userid"), concat_ws(",", col("item_list")).as("item_list"))
  }

计算与每个用户向量最相似的topK个物料

  /**
   * For every user, ranks all broadcast items by similarity and keeps the best `topK` ids.
   *
   * @param ss            active Spark session
   * @param userEmb       (userid, embedding) pairs
   * @param itemEmbBc     broadcast itemid -> embedding map (the candidate pool)
   * @param simFun        "dot" for dot product; any other value means cosine similarity
   * @param topK          number of item ids kept per user (default 50)
   * @param numPartitions partitions used to spread the per-user scoring (default 2000)
   * @return DataFrame with columns `userid` and `top50`, where `top50` holds the top-K item
   *         ids, best first, comma-joined. The column name is kept for downstream compatibility
   *         even when `topK != 50`.
   */
  def calcTopK(ss: SparkSession, userEmb: Dataset[(String, DenseVector[Double])],
               itemEmbBc: Broadcast[Map[String, DenseVector[Double]]], simFun: String = "cos",
               topK: Int = 50, numPartitions: Int = 2000): DataFrame = {
    import ss.implicits._
    userEmb.repartition(numPartitions).map(row => {
      val uid = row._1
      val ve = row._2
      // calcSimilarity returns every candidate sorted by score, so truncate to the K best.
      val topItems = calcSimilarity(ve, itemEmbBc.value, simFun).take(topK).map(_._1).mkString(",")
      (uid, topItems)
    }).toDF("userid", "top50")
  }

  /**
   * Scores every candidate item against a user vector and ranks them.
   *
   * @param user       user embedding
   * @param candidates itemid -> item embedding
   * @param simFun     similarity passed through to `cosSim` ("dot" or cosine otherwise)
   * @return all (itemid, score) pairs, highest score first
   */
  def calcSimilarity(user: DenseVector[Double], candidates: Map[String, DenseVector[Double]], simFun: String = "cos"): Array[(String, Double)] = {
    val scoredItems = candidates.map { case (itemId, itemVec) => itemId -> cosSim(user, itemVec, simFun) }
    // Descending by score; sortBy is stable, so ties keep the map's iteration order.
    scoredItems.toArray.sortBy { case (_, score) => -score }
  }

  /**
   * Similarity between two vectors.
   *
   * @param simFun "dot" returns the raw dot product; any other value returns cosine similarity,
   *               which is NaN when either vector has zero norm (0 / 0)
   * @return the similarity score
   */
  def cosSim(v1: DenseVector[Double], v2: DenseVector[Double], simFun: String = "cos"): Double = {
    val dotProduct = v1 dot v2
    simFun match {
      case "dot" => dotProduct
      case _     => dotProduct / (norm(v1) * norm(v2))
    }
  }

将根据用户向量计算出的topK相似物料与真实测试集进行对比,计算HitRecall

 /**
   * Compares each user's recommended items with their ground-truth items and prints Hit@K.
   *
   * @param ss     active Spark session
   * @param df     columns `userid`, `top50` (comma-joined ids, best first)
   * @param sample columns `userid`, `item_list` (comma-joined ground-truth ids)
   *
   * Prints:
   *  - the per-user table (click_count = distinct ground-truth items, Hit@K = hits / click_count);
   *  - `describe()` of that table, whose means are per-user (macro) averages;
   *  - the overall HR@K = total hits / total ground-truth items (micro average).
   * Only users present in both inputs are evaluated (inner join).
   */
 def calcHitRatio(ss: SparkSession, df: DataFrame, sample: DataFrame): Unit = {
    import ss.implicits._
    val hitDf = df.join(sample, Seq("userid"))
      .map(row => {
        val userId = row.getAs[String]("userid")
        val itemSet = row.getAs[String]("item_list").split(",").toSet
        // Per-user denominator: number of distinct ground-truth items.
        val count = itemSet.size.toDouble
        val top50 = row.getAs[String]("top50").split(",")
        // Fraction of this user's ground-truth items found among the first k recommendations.
        def hitRatio(k: Int): Double = top50.take(k).toSet.intersect(itemSet).size / count
        (userId, count, hitRatio(3), hitRatio(5), hitRatio(10), hitRatio(20), hitRatio(50))
      }).toDF("userid", "click_count", "Hit@3", "Hit@5", "Hit@10", "Hit@20", "Hit@50")
      .persist(StorageLevel.MEMORY_AND_DISK)

    hitDf.show(false)

    // Means here are per-user (macro) averages of each Hit@K column.
    hitDf.describe().show(1000, false)

    // Overall HR@K = sum of hits / sum of ground-truth items; hits are recovered as ratio * count.
    val hitCols = Seq("Hit@3", "Hit@5", "Hit@10", "Hit@20", "Hit@50")
    val overall = hitCols.map(c => (sum(col(c) * col("click_count")) / sum(col("click_count"))).as(c))
    hitDf.agg(overall.head, overall.tail: _*).show(false)

    hitDf.unpersist()
  }

初始化参数,主要结构

  /**
   * Reads job parameters into the module-level settings and echoes them to stdout.
   *
   * NOTE(review): ArgParser is a project helper not shown here; presumably
   * `getStringValue(key, default)` returns `default` when the key is absent — confirm.
   */
  def init(args: Array[String]): Unit = {
    val argParser = new ArgParser(args)
    userEmbPath = argParser.getStringValue("userEmbPath", "")
    itemEmbPath = argParser.getStringValue("itemEmbPath", "")
    samplePath = argParser.getStringValue("samplePath", "")
    // Falls back to whatever simFunction currently holds.
    simFunction = argParser.getStringValue("simFunction", simFunction)

    Seq(
      "userEmbPath" -> userEmbPath,
      "itemEmbPath" -> itemEmbPath,
      "samplePath" -> samplePath,
      "simFunction" -> simFunction
    ).foreach { case (name, value) => println(s"$name:$value") }
  }

最终代码

import breeze.linalg.{DenseVector, norm}
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.sql._
import org.apache.spark.sql.functions._
import org.apache.spark.storage.StorageLevel

import scala.collection.mutable.ArrayBuffer
object HitRecall {
  // Job settings, filled in from command-line arguments by init().
  var userEmbPath: String = ""
  var itemEmbPath: String = ""
  var samplePath: String = ""
  var simFunction: String = "cos"

  /**
   * Entry point: loads user/item embeddings and ground-truth samples, computes each user's
   * top recommended items, and prints Hit@K statistics.
   *
   * NOTE(review): init, loadUserEmb, loadItemEmbBc, calcTopK, loadSample and calcHitRatio are
   * not defined in this object as shown; presumably they belong to it in the full program.
   */
  def main(args: Array[String]): Unit = {
    val ss = SparkSession.builder.appName("HitRecall").getOrCreate
    try {
      init(args)

      val userEmb = loadUserEmb(ss, userEmbPath)
      val itemEmbBc = loadItemEmbBc(ss, itemEmbPath)
      val simDf = calcTopK(ss, userEmb, itemEmbBc, simFunction)
      val sampleDf = loadSample(ss, samplePath)
      calcHitRatio(ss, simDf, sampleDf)
    } finally {
      // Release the Spark context even if a stage fails.
      ss.stop()
    }
  }
}
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值