1、代码
import org.apache.spark.sql.SparkSession
import toby.gao.config.modelConfig
/**
 * Recommendation example built on `org.apache.spark.ml.recommendation`
 * (ALS / ALSModel, alternating least squares).
 *
 * Reads "userId::movieId::rating::timestamp" lines from
 * `sample_movielens_ratings.txt`, trains an ALS model on a random 80/20 split,
 * and prints regression metrics (RMSE, R2, MSE, explained variance) and ranking
 * metrics (mean average precision, precision@5) for the held-out rows.
 */
object example28 {

  /**
   * Runs the whole pipeline: load, split, fit, predict, evaluate.
   *
   * @param args command-line arguments (unused)
   */
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("recommendation")
      .enableHiveSupport()
      .getOrCreate()

    // 1 - load data: split each text line on "::" and cast the four fields
    val ratings = spark.read.textFile(modelConfig.dataPath + "/data/sample_movielens_ratings.txt")
      .selectExpr("split(value , '::') as col")
      .selectExpr(
        "cast(col[0] as int) as userId",
        "cast(col[1] as int) as movieId",
        "cast(col[2] as float) as rating",
        "cast(col[3] as long) as timestamp") // timestamp
    ratings.cache()
    ratings.show()

    // 2 - train / test split
    val Array(training, test) = ratings.randomSplit(Array(0.8, 0.2))

    // 3 - model: ALS (alternating least squares)
    // coldStartStrategy = "drop" removes test rows whose user or item was not
    // seen during training; with the default ("nan") those rows get a NaN
    // prediction, which turns RMSE into NaN and pollutes the ranking step.
    import org.apache.spark.ml.recommendation.ALS
    val als = new ALS()
      .setMaxIter(5)
      .setRegParam(0.01)
      .setUserCol("userId")
      .setItemCol("movieId")
      .setRatingCol("rating")
      .setColdStartStrategy("drop")
    println("Toby:" + als.explainParams())

    // 4 - model fit
    val alsModel = als.fit(training)

    // 5 - model predict
    val predictions = alsModel.transform(test)
    predictions.cache()
    predictions.show()

    // 6 - model evaluate (RMSE)
    import org.apache.spark.ml.evaluation.RegressionEvaluator
    val evaluator = new RegressionEvaluator()
      .setMetricName("rmse")
      .setLabelCol("rating")
      .setPredictionCol("prediction")
    val rmse = evaluator.evaluate(predictions)
    println(s"Toby : Root-mean-square error = $rmse")

    // 6 - model metrics (R2, MSE, explained variance)
    // RegressionMetrics expects (prediction, observation) pairs, in that order;
    // r2 and explainedVariance are not symmetric in the two values.
    import org.apache.spark.mllib.evaluation.RegressionMetrics
    val regComparison = predictions.select("prediction", "rating")
      .rdd.map(x => (x.getFloat(0).toDouble, x.getFloat(1).toDouble))
    val metrics = new RegressionMetrics(regComparison)
    println("Toby: " + metrics.r2)
    println("Toby: " + metrics.meanSquaredError)
    println("Toby: " + metrics.explainedVariance)

    // 7 - ranking metrics: how accurate are the top recommended candidates
    // 7-1 movies each user actually liked (threshold: rating > 2.5 counts as liked)
    import org.apache.spark.mllib.evaluation.RankingMetrics
    import org.apache.spark.sql.functions.{col, collect_list, expr, sort_array, struct}
    val perUserActual = predictions
      .where("rating > 2.5")
      .groupBy("userId")
      .agg(expr("collect_set(movieId) as movies"))
    perUserActual.cache()
    perUserActual.show()

    // 7-2 per user, the movies ordered by predicted rating from high to low.
    // The sort happens inside the aggregate: an orderBy placed before groupBy is
    // not guaranteed to survive the shuffle that collect_list runs behind.
    val perUserPredictions = predictions
      .groupBy("userId")
      .agg(sort_array(collect_list(struct(col("prediction"), col("movieId"))), asc = false).as("ranked"))
      .select(col("userId"), col("ranked.movieId").as("movies"))
    perUserPredictions.cache()
    perUserPredictions.show()

    // 7-3 pair each user's top-15 recommendations with the movies they liked.
    // RankingMetrics expects (predicted ranking, ground-truth set), in that order.
    import spark.implicits._
    val perUserActualvPred = perUserActual.join(perUserPredictions, Seq("userId"))
      .map { row =>
        // after the join: column 0 = userId, 1 = liked movies, 2 = ranked predictions
        val liked = row.getSeq[Int](1).toArray
        val recommended = row.getSeq[Int](2).toArray.take(15)
        (recommended, liked)
      }
    perUserActualvPred.cache()
    perUserActualvPred.show()
    val ranks = new RankingMetrics(perUserActualvPred.rdd)

    // 7-4 top-k accuracy
    println("Toby: " + ranks.meanAveragePrecision) // mean average precision
    println("Toby: " + ranks.precisionAt(5)) // precision at top 5

    spark.stop()
  }
}
2、结果
1 - load data
+------+-------+------+----------+
|userId|movieId|rating| timestamp|
+------+-------+------+----------+
|     0|      2|   3.0|1424380312|
|     0|      3|   1.0|1424380312|
|     0|      5|   2.0|1424380312|
|     0|      9|   4.0|1424380312|
|     0|     11|   1.0|1424380312|
|     0|     12|   2.0|1424380312|
|     0|     15|   1.0|1424380312|
|     0|     17|   1.0|1424380312|
|     0|     19|   1.0|1424380312|
|     0|     21|   1.0|1424380312|
|     0|     23|   1.0|1424380312|
|     0|     26|   3.0|1424380312|
|     0|     27|   1.0|1424380312|
|     0|     28|   1.0|1424380312|
|     0|     29|   1.0|1424380312|
|     0|     30|   1.0|1424380312|
|     0|     31|   1.0|1424380312|
|     0|     34|   1.0|1424380312|
|     0|     37|   1.0|1424380312|
|     0|     41|   2.0|1424380312|
+------+-------+------+----------+

3 - ALS model params
见Scala文档: http://spark.apache.org/docs/2.3.0/api/scala/index.html#org.apache.spark.ml.recommendation.ALS

5 - model prediction
+------+-------+------+----------+----------+
|userId|movieId|rating| timestamp|prediction|
+------+-------+------+----------+----------+
|    26|     85|   1.0|1424380312| 1.0803492|
|    28|     85|   1.0|1424380312|  3.692985|
|     8|     85|   5.0|1424380312| 3.2585845|
|    28|     94|   1.0|1424380312| 2.6824238|
|    27|     57|   1.0|1424380312| 1.6582177|
|    19|     65|   1.0|1424380312|-1.2475832|
|    27|     27|   3.0|1424380312| -2.898336|
|     1|     19|   1.0|1424380312| 2.0594523|
|    24|     19|   1.0|1424380312|   2.08305|
|    11|     19|   4.0|1424380312| 2.0348344|
|    17|     19|   1.0|1424380312|  4.782441|
|     0|     72|   1.0|1424380312| 0.7783812|
|    18|     70|   1.0|1424380312|-1.1651708|
|     3|     70|   1.0|1424380312|-0.3618136|
|     4|     70|   4.0|1424380312| 1.5592415|
|    11|     70|   1.0|1424380312|-1.2107608|
|     5|     70|   1.0|1424380312|  1.499864|
|    15|     98|   3.0|1424380312| 2.0334811|
|    14|     98|   1.0|1424380312| 3.1419742|
|     9|     98|   1.0|1424380312|-1.0441759|
+------+-------+------+----------+----------+

6 - RMSE
Toby : Root-mean-square error = 1.78834564765203

6 - R2 \ MSE \ explainedVariance
Toby: -0.6495449885926385
Toby: 3.1981801554759586
Toby: 1.6947224080258068

7-1 - 用户喜欢的电影列表
+------+--------------------+
|userId|              movies|
+------+--------------------+
|    27|        [66, 27, 18]|
|    11|[48, 81, 30, 19, ...|
|     2|                [34]|
|    16|             [5, 90]|
|     9|     [49, 2, 90, 43]|
|    22|    [51, 30, 75, 87]|
|     8|    [52, 85, 31, 95]|
|    21|    [96, 53, 58, 29]|
|    26|        [68, 21, 18]|
|     1|            [21, 77]|
|     6|        [63, 25, 61]|
|    17|                [90]|
|    15|                [98]|
|     4|        [60, 70, 29]|
|    25|        [33, 12, 71]|
|     7|                [29]|
|    20|            [88, 90]|
|    23|                [87]|
|    12|        [16, 31, 23]|
|     5|                [64]|
+------+--------------------+

7-2 给用户推荐的电影列表
+------+--------------------+
|userId|              movies|
+------+--------------------+
|    27|[27, 25, 22, 31, ...|
|    19|[95, 55, 62, 58, ...|
|    11|[16, 64, 69, 48, ...|
|     2|[52, 34, 15, 10, ...|
|    29|     [33, 79, 62, 3]|
|    16|[38, 90, 99, 56, ...|
|     9|[22, 36, 71, 90, ...|
|     0|[69, 96, 12, 48, ...|
|    22|[51, 26, 30, 45, ...|
|     8|[52, 7, 92, 88, 5...|
|     3|     [35, 70, 33, 9]|
|    21|[96, 30, 29, 0, 8...|
|    26|[21, 85, 18, 62, ...|
|     1|[44, 82, 97, 21, ...|
|     6|[63, 88, 9, 25, 4...|
|    28|[98, 85, 6, 58, 9...|
|    17|[19, 60, 29, 90, ...|
|    15|[33, 29, 37, 98, ...|
|     4|[70, 8, 50, 71, 2...|
|    25|[56, 64, 33, 31, ...|
+------+--------------------+

7-3 喜欢的电影 和推荐的电影
+--------------------+--------------------+
|                  _1|                  _2|
+--------------------+--------------------+
|        [66, 27, 18]|[27, 25, 22, 31, ...|
|[48, 81, 30, 19, ...|[16, 64, 69, 48, ...|
|                [34]|[52, 34, 15, 10, ...|
|             [5, 90]|[38, 90, 99, 56, ...|
|     [49, 2, 90, 43]|[22, 36, 71, 90, ...|
|    [51, 30, 75, 87]|[51, 26, 30, 45, ...|
|    [52, 85, 31, 95]|[52, 7, 92, 88, 5...|
|    [96, 53, 58, 29]|[96, 30, 29, 0, 8...|
|        [68, 21, 18]|[21, 85, 18, 62, ...|
|            [21, 77]|[44, 82, 97, 21, ...|
|        [63, 25, 61]|[63, 88, 9, 25, 4...|
|                [90]|[19, 60, 29, 90, ...|
|                [98]|[33, 29, 37, 98, ...|
|        [60, 70, 29]|[70, 8, 50, 71, 2...|
|        [33, 12, 71]|[56, 64, 33, 31, ...|
|                [29]|[15, 49, 33, 29, ...|
|            [88, 90]|[12, 32, 30, 88, ...|
|                [87]|[0, 69, 87, 77, 1...|
|        [16, 31, 23]|[23, 86, 53, 24, ...|
|                [64]|[64, 38, 70, 31, ...|
+--------------------+--------------------+

7-4
平均准确率 (meanAveragePrecision): 0.2770502645502645
Top准确率 (precisionAt(5)): 0.5166666666666667