A Scala Implementation of HitRecall
HitRecall is one of the standard evaluation metrics for top-K recommendation. It is defined as
HR@K = NumberOfHits@K / GT
where the denominator GT is the size of the whole test set (the ground truth) and the numerator is the total number of test-set items hit by the top-K recommendations.
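For intuition, here is a minimal, self-contained sketch of the per-user computation (the item ids and list contents are made up for illustration):
  // HR@K for one user: 4 ground-truth items, the top-10 list hits 2 of them => 0.5.
  def hrAtK(topK: Seq[String], groundTruth: Set[String]): Double =
    topK.toSet.intersect(groundTruth).size.toDouble / groundTruth.size

  val groundTruth = Set("i1", "i2", "i3", "i4")
  val top10 = Seq("i1", "i9", "i3", "i8", "i7", "i6", "i5", "i0", "ix", "iy")
  println(hrAtK(top10, groundTruth)) // hits i1 and i3 => 2 / 4 = 0.5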
First export the user vectors and the item vectors; assume both share the same format:
userid1||0.03388985991477966,-0.05629376694560051,-0.28580108284950256,-0.03301679715514183,-0.00348331220448017,0.23206464946269989,-0.07762476801872253,0.12137320637702942
userid2||0.02444665506482124,0.17300844192504883,-0.17057773470878601,-0.12084504216909409,-0.04795640334486961,0.18982534110546112,-0.10042955726385117,0.10137677937746048
The user vectors are loaded as follows:
// Parse "userid||v1,v2,..." lines into (userId, DenseVector) pairs.
def loadUserEmb(ss: SparkSession, path: String): Dataset[(String, DenseVector[Double])] = {
  import ss.implicits._
  // breeze.linalg.DenseVector has no built-in Spark encoder, so fall back to Kryo.
  implicit val denseVectorEncoder = org.apache.spark.sql.Encoders.kryo[(String, breeze.linalg.DenseVector[Double])]
  ss.read.text(path).map(x => {
    // Split on the "||" delimiter: key is the user id, value the raw vector string.
    val line = x.getString(0).split("\\|\\|")
    val key = line(0)
    val value = line(1)
    (key, value)
  })
    .toDF("userid", "emb")
    .map(row => {
      val userId = row.getAs[String]("userid")
      // The vector is a comma-separated list of doubles.
      val ve = DenseVector(row.getAs[String]("emb").split(",").map(_.toDouble))
      (userId, ve)
    })(denseVectorEncoder)
}
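Before wiring this into Spark, the split logic can be sanity-checked on a single line (a sketch with truncated vector values; no Spark session needed):
  // Stand-alone check of the "||" and "," parsing.
  val line = "userid1||0.0338,-0.0562,-0.2858"
  val Array(key, value) = line.split("\\|\\|")
  val vec = breeze.linalg.DenseVector(value.split(",").map(_.toDouble))
  assert(key == "userid1" && vec.length == 3)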
The item vectors are comparatively few, so the whole map can be broadcast directly. They use the same format:
itemid1||0.03388985991477966,-0.05629376694560051,-0.28580108284950256,-0.03301679715514183,-0.00348331220448017,0.23206464946269989,-0.07762476801872253,0.12137320637702942
itemid2||0.02444665506482124,0.17300844192504883,-0.17057773470878601,-0.12084504216909409,-0.04795640334486961,0.18982534110546112,-0.10042955726385117,0.10137677937746048
// Load item vectors, collect them to the driver, and broadcast the map to all executors.
def loadItemEmbBc(ss: SparkSession, path: String): Broadcast[Map[String, DenseVector[Double]]] = {
  import ss.implicits._
  implicit val denseVectorEncoder = org.apache.spark.sql.Encoders.kryo[(String, breeze.linalg.DenseVector[Double])]
  val ds = ss.read.text(path).map(x => {
    val line = x.getString(0).split("\\|\\|")
    val key = line(0)
    val value = line(1)
    (key, value)
  }).toDF("itemid", "emb")
    .map(row => {
      val itemId = row.getAs[String]("itemid")
      val ve = DenseVector(row.getAs[String]("emb").split(",").map(_.toDouble))
      (itemId, ve)
    })(denseVectorEncoder)
    .persist(StorageLevel.MEMORY_AND_DISK)
  // The count doubles as a sanity check and materializes the cached dataset.
  println("item emb count:" + ds.count())
  val bc = ss.sparkContext.broadcast(ds.collect().toMap)
  ds.unpersist() // the cached dataset is no longer needed once the map is collected
  bc
}
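Broadcasting is only viable while the full item map fits in driver and executor memory. A rough lower bound on the payload (ignoring JVM object overhead; the catalogue size below is a made-up figure):
  // items x dimensions x 8 bytes per Double ~ raw vector payload.
  val numItems = 1000000 // hypothetical catalogue size
  val dim = 8            // dimension of the sample vectors above
  val approxBytes = numItems.toLong * dim * 8
  println(s"~${approxBytes / (1024 * 1024)} MB of raw vector data") // ~61 MB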
Next, load the samples, i.e. the ground-truth test set. Each line has the form itemId_userId|label|features:
item1_user5|0|86755830550:20:0.298238;116052872301:27:1.0;1251978090330:291:1.0986122886681098
item2_user2|1|86755830550:20:0.298238;116052872301:27:1.0;1251978090330:291:1.0986122886681098
item1_user2|1|86755830550:20:0.298238;116052872301:27:1.0;1251978090330:291:1.0986122886681098
// Parse the samples, keep positives only, and aggregate each user's
// ground-truth items into a comma-separated list.
def loadSample(ss: SparkSession, path: String): DataFrame = {
  import ss.implicits._
  val pathList = listPath(ss, path) // project helper (not shown) that expands the path into input files
  val sample = ss.read.option("sep", "|").csv(pathList: _*)
    .toDF("info", "label", "sample")
    .filter(col("label") === "1") // only positive samples count as ground truth
    .map(row => {
      // "info" is "itemId_userId"
      val info = row.getAs[String]("info").split("_")
      val itemId = info.head
      val userId = info.last
      (userId, itemId)
    }).toDF("userid", "itemid")
    .groupBy("userid")
    .agg(collect_list("itemid").as("item_list"))
    .select(col("userid"), concat_ws(",", col("item_list")).as("item_list"))
  sample
}
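As a sanity check, the same aggregation can be run on the three example lines as an in-memory DataFrame (a minimal sketch assuming a SparkSession ss is in scope; the feature column is elided and the split assumes a single underscore in the id pair):
  import ss.implicits._
  // The three example rows; only the two label-1 rows count as ground truth.
  val demo = Seq(
    ("item1_user5", "0", ""),
    ("item2_user2", "1", ""),
    ("item1_user2", "1", "")
  ).toDF("info", "label", "sample")
  demo.filter(col("label") === "1")
    .select(split(col("info"), "_").getItem(1).as("userid"),
            split(col("info"), "_").getItem(0).as("itemid"))
    .groupBy("userid")
    .agg(concat_ws(",", collect_list("itemid")).as("item_list"))
    .show(false)
  // Expected single row: user2 -> "item2,item1" (collect_list order is not guaranteed)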
Compute the top-N most similar items for every user vector:
// For every user vector, score all broadcast item vectors and keep the top 50.
def calcTopK(ss: SparkSession, userEmb: Dataset[(String, DenseVector[Double])],
             itemEmbBc: Broadcast[Map[String, DenseVector[Double]]], simFun: String = "cos"): DataFrame = {
  import ss.implicits._
  userEmb.repartition(2000).map(row => {
    val uid = row._1
    val ve = row._2
    // calcSimilarity returns scores sorted descending; keep the 50 best item ids.
    val top50 = calcSimilarity(ve, itemEmbBc.value, simFun).take(50).map(_._1).mkString(",")
    (uid, top50)
  }).toDF("userid", "top50")
}
// Brute-force scoring: similarity of one user vector against every candidate,
// sorted by score in descending order.
def calcSimilarity(user: DenseVector[Double], candidates: Map[String, DenseVector[Double]], simFun: String = "cos"): Array[(String, Double)] = {
  candidates.map(i => {
    val candidate = i._2
    val itemId = i._1
    val score = cosSim(user, candidate, simFun)
    (itemId, score)
  }).toArray
    .sortBy(-_._2)
}
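// Optional optimization (a sketch, not part of the original code): sorting every
// candidate costs O(n log n) per user, but only the top 50 are kept. A bounded
// min-heap produces the same result in O(n log k).
import scala.collection.mutable

def topKSimilarity(user: DenseVector[Double],
                   candidates: Map[String, DenseVector[Double]],
                   k: Int, simFun: String = "cos"): Array[(String, Double)] = {
  // Min-heap on score: the weakest of the current top-k sits at the head.
  val heap = mutable.PriorityQueue.empty[(String, Double)](Ordering.by[(String, Double), Double](-_._2))
  candidates.foreach { case (itemId, vec) =>
    val score = cosSim(user, vec, simFun)
    if (heap.size < k) heap.enqueue((itemId, score))
    else if (score > heap.head._2) { heap.dequeue(); heap.enqueue((itemId, score)) }
  }
  heap.dequeueAll.toArray.sortBy(-_._2) // descending, same contract as calcSimilarity
}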
// Similarity between two vectors: raw dot product when simFun == "dot",
// cosine similarity otherwise.
def cosSim(v1: DenseVector[Double], v2: DenseVector[Double], simFun: String = "cos"): Double = {
  if (simFun == "dot") {
    v1 dot v2
  } else {
    val v0 = v1 dot v2
    v0 / (norm(v1) * norm(v2))
  }
}
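A quick numeric check on hand-picked vectors (not from the exported data):
  // cos((1,0),(1,1)) = 1 / sqrt(2) ~ 0.7071; dot((1,0),(1,1)) = 1.0
  val a = DenseVector(1.0, 0.0)
  val b = DenseVector(1.0, 1.0)
  println(cosSim(a, b))        // 0.7071067811865475
  println(cosSim(a, b, "dot")) // 1.0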
Finally, compare each user's top similar items against the ground-truth test set and compute the HitRecall:
// Join the per-user top-50 list with the ground truth and compute Hit@K per user.
def calcHitRatio(ss: SparkSession, df: DataFrame, sample: DataFrame): Unit = {
  import ss.implicits._
  val hitDf = df.join(sample, Seq("userid"))
    .map(row => {
      val userId = row.getAs[String]("userid")
      val itemSet = row.getAs[String]("item_list").split(",").toSet
      val count = itemSet.size.toDouble // size of this user's ground truth
      val top50 = row.getAs[String]("top50").split(",")
      // Hit@K = |topK intersect groundTruth| / |groundTruth|
      val hitTop3 = top50.take(3).toSet.intersect(itemSet).size / count
      val hitTop5 = top50.take(5).toSet.intersect(itemSet).size / count
      val hitTop10 = top50.take(10).toSet.intersect(itemSet).size / count
      val hitTop20 = top50.take(20).toSet.intersect(itemSet).size / count
      val hitTop50 = top50.take(50).toSet.intersect(itemSet).size / count
      (userId, count, hitTop3, hitTop5, hitTop10, hitTop20, hitTop50)
    }).toDF("userid", "click_count", "Hit@3", "Hit@5", "Hit@10", "Hit@20", "Hit@50")
    .persist(StorageLevel.MEMORY_AND_DISK)
  hitDf.show(false)
  // describe() reports count/mean/stddev/min/max; the mean row is the average HR@K.
  hitDf.describe().show(1000, false)
}
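If only the corpus-level averages are wanted, they can be computed directly rather than read off the describe() table (a sketch reusing the hitDf cached inside calcHitRatio):
  // Average every per-user Hit@K column in a single aggregation.
  hitDf.agg(
    avg("Hit@3").as("HR@3"),
    avg("Hit@5").as("HR@5"),
    avg("Hit@10").as("HR@10"),
    avg("Hit@20").as("HR@20"),
    avg("Hit@50").as("HR@50")
  ).show(false)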
Parameter initialization and the overall skeleton:
def init(args: Array[String]): Unit = {
  val parser = new ArgParser(args) // project argument parser (not shown)
  userEmbPath = parser.getStringValue("userEmbPath", "")
  itemEmbPath = parser.getStringValue("itemEmbPath", "")
  samplePath = parser.getStringValue("samplePath", "")
  simFunction = parser.getStringValue("simFunction", simFunction)
  println(s"userEmbPath:$userEmbPath")
  println(s"itemEmbPath:$itemEmbPath")
  println(s"samplePath:$samplePath")
  println(s"simFunction:$simFunction")
}
The complete code:
import breeze.linalg.{DenseVector, norm}
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.sql._
import org.apache.spark.sql.functions._
import org.apache.spark.storage.StorageLevel

// The helper methods shown above (loadUserEmb, loadItemEmbBc, loadSample,
// calcTopK, calcSimilarity, cosSim, calcHitRatio, init) all live in this object.
object HitRecall {
  var userEmbPath: String = ""
  var itemEmbPath: String = ""
  var samplePath: String = ""
  var simFunction: String = "cos"

  def main(args: Array[String]): Unit = {
    val ss = SparkSession.builder.appName("HitRecall").getOrCreate()
    init(args)
    val userEmb = loadUserEmb(ss, userEmbPath)
    val itemEmbBc = loadItemEmbBc(ss, itemEmbPath)
    val simDf = calcTopK(ss, userEmb, itemEmbBc, simFunction)
    val sampleDf = loadSample(ss, samplePath)
    calcHitRatio(ss, simDf, sampleDf)
  }
}
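For a quick local smoke test, the same pipeline can be driven from a local-mode session (a sketch; the /tmp paths and local master are assumptions, not from the original setup):
  // Local-mode run against small sample files.
  val ss = SparkSession.builder
    .appName("HitRecall-local")
    .master("local[*]")
    .getOrCreate()
  val userEmb = loadUserEmb(ss, "/tmp/user_emb.txt")   // hypothetical paths
  val itemEmbBc = loadItemEmbBc(ss, "/tmp/item_emb.txt")
  val sampleDf = loadSample(ss, "/tmp/sample.txt")
  calcHitRatio(ss, calcTopK(ss, userEmb, itemEmbBc, "cos"), sampleDf)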