sparkmllib交替最小二乘法

http://spark.apache.org/docs/2.2.0/ml-collaborative-filtering.html

不需要用户和商品属性的信息,这类算法通常称为协同过滤算法

例子:根据两个用户的年龄相同来判断他们可能有相似的偏好,这不叫协同过滤。相反,根据两个用户播放过许多相同歌曲来判断他们可能都喜欢某首歌,这才叫协同过滤。

SparkMLlib 的ALS算法 要求用户和产品ID必须是数值型,这意味着大于Integer.MAX_VALUE(2147483647)的值都是非法的。

训练出的模型可以保存到文件,还可以从文件load模型

 
package test
import org.apache.spark.ml.evaluation.RegressionEvaluator
import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}
import org.apache.spark.ml.recommendation.{ALS, ALSModel}

import org.apache.spark.ml.Model
/**
  * Created by othc on 2018-01-19.
  */
object ALS1 {
  case class Rating(userId:Int,artistId:Int,count:Float)
  def main(args: Array[String]): Unit = {
    //session
    val spark = SparkSession.builder().config("spark.sql.warehouse.dir","/usr/local/testdata/spark-warehouse").appName("als").getOrCreate()
    import spark.implicits._
    //用户id 艺术家id 次数
    val rawUserArtstData: Dataset[String] = spark.read.textFile("/usr/local/mldata/user_artist_data.txt")
    //艺术家id 名字
    val rawArtistData =  spark.read.textFile("/usr/local/mldata/artist_data.txt")
    val artistById = rawArtistData.flatMap(line => {
      val (id, name) = line.span(_ != '\t')
      if (name.isEmpty) {
        None
      } else {
        try {
          Some((id.toInt, name.trim))
        } catch {
          case e: NumberFormatException => None
        }
      }
    })
    //将错误的艺术家id或不标准的id 映射成艺术家正规的名字
    val rawArtistAlias = spark.read.textFile("/usr/local/mldata/artist_alias.txt")
    val artistAlias = rawArtistAlias.flatMap(line=>{
      val tokens = line.split("\t")
      if(tokens(0).isEmpty){
        None
      }else{
        Some((tokens(0).toInt,tokens(1).toInt))
      }
    }).rdd.collectAsMap()
    //将map变量广播
    val bArtistAlias = spark.sparkContext.broadcast(artistAlias)

    val trainData = rawUserArtstData.map(line=>{
      val Array(userId,artistId,count) = line.split(" ").map(_.toInt)
      val finalArtistId= bArtistAlias.value.getOrElse(artistId,artistId)
      Rating(userId,finalArtistId,count.toFloat)
    }).toDF().cache()
    val Array(train,test) = trainData.randomSplit(Array(0.8,0.2))
    val als: ALS = new ALS().setMaxIter(5).setRegParam(0.01).setUserCol("userId").setItemCol("artistId").setRatingCol("count")

    val model: ALSModel = als.fit(train)
    //去掉userid或artistId 是NAN的
    model.setColdStartStrategy("drop")
//    保存模型
//    model.save("")
//    //加载模型
//    import org.apache.spark.ml.recommendation.ALS._
//    val load1: ALS = load("")

    val predictions: DataFrame = model.transform(test)
    val evaluator = new RegressionEvaluator().setMetricName("rmse").setLabelCol("count").setPredictionCol("prediction")
    val rmse = evaluator.evaluate(predictions)
    println(s"Root-mean-square error = $rmse")

    //每个用户推荐的前十个电影
    val userRecs: DataFrame = model.recommendForAllUsers(10)
    userRecs.rdd.saveAsTextFile("/usr/local/testdata/")

    //每个电影推荐的十个用户
    val movieRecs = model.recommendForAllItems(10)
    movieRecs.rdd.saveAsTextFile("/usr/local/testdata/")
    userRecs.show()
    movieRecs.show()

    spark.stop()
  }
}

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值