http://spark.apache.org/docs/2.2.0/ml-collaborative-filtering.html
只利用用户与商品之间的交互数据(如评分、播放次数),而不需要用户或商品自身属性信息的推荐算法,通常称为协同过滤算法。
例子:根据两个用户的年龄相同来判断他们可能有相似的偏好,这不叫协同过滤。相反,根据两个用户播放过许多相同歌曲来判断他们可能都喜欢某首歌,这才叫协同过滤。
Spark ML 的 ALS 算法目前只支持整数类型的用户 ID 和商品 ID;ID 列也可以是其他数值类型,但取值必须落在 Int 范围内,因此大于 Integer.MAX_VALUE(2147483647)的 ID 都是非法的。
训练得到的模型(ALSModel)可以用 save 保存到文件,之后再用 ALSModel.load 从文件重新加载。
package test

import org.apache.spark.ml.evaluation.RegressionEvaluator
import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}
import org.apache.spark.ml.recommendation.{ALS, ALSModel}
import org.apache.spark.ml.Model

import scala.util.Try

/**
 * Created by othc on 2018-01-19.
 *
 * Trains a Spark ML ALS collaborative-filtering model on user/artist play counts, prints the
 * RMSE on a random 20% hold-out split, and writes the top-10 recommendations per user and per
 * artist as text files.
 */
object ALS1 {

  /** One input row: user `userId` played artist `artistId` a total of `count` times. */
  case class Rating(userId: Int, artistId: Int, count: Float)

  /**
   * Parent directory of the recommendation output. Each result is written to its own
   * subdirectory: `saveAsTextFile` refuses to write into a directory that already exists, so
   * two saves into the same path (or into a directory that already holds the SQL warehouse)
   * would fail. Re-running the job requires deleting the previous output first.
   */
  private val OutputDir = "/usr/local/testdata"

  /**
   * Parses `artist_data.txt` lines of the form `id<TAB>name` into `(artistId, name)` pairs.
   * Lines without a tab or with a non-numeric id are skipped.
   */
  private def parseArtistNames(spark: SparkSession, lines: Dataset[String]): Dataset[(Int, String)] = {
    import spark.implicits._
    lines.flatMap[(Int, String)] { line =>
      // span splits at the first tab; `name` still starts with that tab, hence the trim below
      val (id, name) = line.span(_ != '\t')
      if (name.isEmpty) None
      else Try(id.trim.toInt).toOption.map(artistId => (artistId, name.trim))
    }
  }

  /**
   * Reads `artist_alias.txt` lines of the form `aliasId<TAB>canonicalId` and collects them on
   * the driver as an alias -> canonical artist id map. Lines with a missing field or a
   * non-numeric id are skipped instead of failing the whole job.
   */
  private def loadArtistAlias(spark: SparkSession, lines: Dataset[String]): scala.collection.Map[Int, Int] = {
    import spark.implicits._
    lines.flatMap[(Int, Int)] { line =>
      line.split('\t') match {
        case Array(aliasId, canonicalId, _*) =>
          for {
            from <- Try(aliasId.trim.toInt).toOption
            to <- Try(canonicalId.trim.toInt).toOption
          } yield (from, to)
        case _ => None
      }
    }.rdd.collectAsMap()
  }

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .config("spark.sql.warehouse.dir", "/usr/local/testdata/spark-warehouse")
      .appName("als")
      .getOrCreate()
    import spark.implicits._

    try {
      // user_artist_data.txt: "userId artistId playCount", space separated
      val rawUserArtistData: Dataset[String] = spark.read.textFile("/usr/local/mldata/user_artist_data.txt")

      // artist_data.txt: "artistId<TAB>name". Not used by the model below; available for
      // turning recommended artist ids back into names.
      val artistById = parseArtistNames(spark, spark.read.textFile("/usr/local/mldata/artist_data.txt"))

      // artist_alias.txt maps misspelled / non-standard artist ids to the canonical id.
      val artistAlias = loadArtistAlias(spark, spark.read.textFile("/usr/local/mldata/artist_alias.txt"))
      // Broadcast the driver-side map so executors receive it once instead of with every task.
      val bArtistAlias = spark.sparkContext.broadcast(artistAlias)

      val trainData = rawUserArtistData.map { line =>
        val Array(userId, artistId, count) = line.split(" ").map(_.toInt)
        val finalArtistId = bArtistAlias.value.getOrElse(artistId, artistId)
        Rating(userId, finalArtistId, count.toFloat)
      }.toDF().cache()

      val Array(train, test) = trainData.randomSplit(Array(0.8, 0.2))

      // NOTE(review): play counts are implicit feedback; Spark documents setImplicitPrefs(true)
      // for such data. It is left off here because the RMSE below compares predictions with raw
      // counts, which only makes sense for the explicit-rating formulation.
      val als: ALS = new ALS()
        .setMaxIter(5)
        .setRegParam(0.01)
        .setUserCol("userId")
        .setItemCol("artistId")
        .setRatingCol("count")
      val model: ALSModel = als.fit(train)
      // Users/artists that appear only in `test` get NaN predictions; drop those rows so the
      // RMSE is not NaN.
      model.setColdStartStrategy("drop")

      // The fitted model can be persisted and restored with:
      //   model.save(path)
      //   val restored: ALSModel = ALSModel.load(path)
      // (ALS.load would load the unfitted estimator, not the trained model.)

      val predictions: DataFrame = model.transform(test)
      val evaluator = new RegressionEvaluator()
        .setMetricName("rmse")
        .setLabelCol("count")
        .setPredictionCol("prediction")
      val rmse = evaluator.evaluate(predictions)
      println(s"Root-mean-square error = $rmse")

      // Top 10 recommended artists for each user.
      val userRecs: DataFrame = model.recommendForAllUsers(10)
      userRecs.rdd.saveAsTextFile(s"$OutputDir/userRecs")
      // Top 10 recommended users for each artist.
      val artistRecs = model.recommendForAllItems(10)
      artistRecs.rdd.saveAsTextFile(s"$OutputDir/artistRecs")

      userRecs.show()
      artistRecs.show()
    } finally {
      // Always release the session, even if loading, training or saving failed.
      spark.stop()
    }
  }
}