// 基于广电数据ALS推荐算法 — ALS recommendation algorithm based on cable-TV (broadcast) viewing data

package com.svg.ALS

import org.apache.spark.ml.recommendation.{ALS, ALSModel}
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.expressions.{UserDefinedFunction, Window}
import org.apache.spark.sql.functions._

import scala.sys.exit

object ALS_ {

  /**
   * Builds a UDF that maps a raw interaction count onto a 1-5 rating band.
   *
   * Bands are multiples of `range` (one fifth of max-min): counts above
   * `4*range` score 5, above `3*range` score 4, ..., everything else 1.
   * The parameter is Long because the counts come from `count("*")`,
   * which Spark types as bigint — an Int UDF would force a narrowing cast.
   *
   * @param range width of one rating band ((max - min) / 5)
   * @return a UDF from count (Long) to score (Int, 1..5)
   */
  private def scoreUdf(range: Long): UserDefinedFunction = udf((cnt: Long) => cnt match {
    case c if c > range * 4 => 5
    case c if c > range * 3 => 4
    case c if c > range * 2 => 3
    case c if c > range     => 2
    case _                  => 1
  })

  def main(args: Array[String]): Unit = {
    System.setProperty("HADOOP_USER_NAME", "hadoop")
    val warehouse = "hdfs://LW-master:9820/user/hive/warehouse"
    val spark: SparkSession = SparkSession.builder()
      .appName("twa")
      .master("local[2]")
      .enableHiveSupport()
      .config("spark.sql.warehouse.dir", warehouse)
      .config("hive.metastore.uris", "thrift://LW-master:9083")
      .config("dfs.client.use.datanode.hostname", "true")
      .getOrCreate()
    spark.sparkContext.setLogLevel("WARN")

    /**
     * 2. Count occurrences of station_name in the index table and derive a
     *    rating column from the count (max score 5).
     * 3. Count occurrences of sm_name in the order table and derive a
     *    rating column from the count (max score 5).
     */
    val usermsg = spark.table("user_profile.usermsg_process").select("phone_no", "sm_code")

    // ---- Rating of stations the user has watched ----
    val stationCounts = spark.read.table("user_profile.media_index_process")
      .filter("station_name!='NULL'")
      .groupBy("phone_no", "station_name")
      .agg(count("*").alias("station_count"))

    // One aggregation action for both extremes (instead of two separate jobs).
    val stationExtremes = stationCounts.agg(max("station_count"), min("station_count")).head()
    val stationRange = (stationExtremes.getLong(0) - stationExtremes.getLong(1)) / 5

    // Assign a dense numeric id per station (ALS item ids must be numeric).
    val station = stationCounts.select("station_name").distinct()
      .withColumn("station_id", row_number().over(Window.orderBy(col("station_name"))))

    val index = stationCounts
      .join(station, "station_name")
      .join(usermsg, "phone_no")
      .withColumn("score", scoreUdf(stationRange)(col("station_count")))
      .withColumn("phone_no", col("phone_no").cast("int"))
      .select("phone_no", "station_name", "station_id", "station_count", "score")
    index.show()

    // Map a service-mode name onto a fixed numeric code for ALS.
    // NOTE: applied to sm_name below, despite the parameter name.
    val re_sm_code = udf((sm_code: String) =>
      if (sm_code.contains("甜果电视")) 1
      else if (sm_code.contains("数字电视")) 2
      else if (sm_code.contains("互动电视")) 3
      else 4)

    // ---- Rating of service modes (sm_name) the user has ordered ----
    val smCounts = spark.read.table("user_profile.order_index_process")
      .filter("sm_name!='NULL'")
      .groupBy("phone_no", "sm_name")
      .agg(count("*").alias("sm_count"))

    val smExtremes = smCounts.agg(max("sm_count"), min("sm_count")).head()
    val smRange = (smExtremes.getLong(0) - smExtremes.getLong(1)) / 5

    val order = smCounts
      .withColumn("score", scoreUdf(smRange)(col("sm_count")))
      .join(usermsg, "phone_no")
      .withColumn("sm_code", re_sm_code(col("sm_name")))
      .withColumn("phone_no", col("phone_no").cast("int"))
      .select("phone_no", "sm_name", "sm_code", "sm_count", "score")
      .orderBy(col("score").desc)
    order.show()

    // ---- ALS model over (phone_no, station_id, score) ----
    val indexals = new ALS()
      .setMaxIter(10)
      .setRegParam(0.01)
      .setUserCol("phone_no")
      .setItemCol("station_id")
      .setRatingCol("score")
    val indexalsInput = index.selectExpr("phone_no", "station_id", "score")
    val Array(indextraid, indextest) = indexalsInput.randomSplit(Array(0.8, 0.2))
    val model = indexals.fit(indextraid)

    // Recommend stations for one sample user, then evaluate on the held-out split.
    val indexe_station = index.select("station_id").distinct()
    val indexForUser = recommendForUser(2294549, "station_id", model, indexe_station, index, station, 6)
    indexForUser.show()
    val predictions = model.transform(indextest)
    predictions.join(station, "station_id").select("phone_no", "station_name", "score", "prediction").show()

    // ---- ALS model over (phone_no, sm_code, score) ----
    val orderals = new ALS()
      .setMaxIter(10)
      .setRegParam(0.01)
      .setUserCol("phone_no")
      .setItemCol("sm_code")
      .setRatingCol("score")
    val orderalsInput = order.selectExpr("phone_no", "sm_code", "score")
    val Array(ordertraid, ordertest) = orderalsInput.randomSplit(Array(0.8, 0.2))
    val ordermodel = orderals.fit(ordertraid)

    // Recommend service modes for the same sample user.
    val orderindexDF = order.select("sm_name", "sm_code").distinct()
    val orderindex = order.select("sm_code").distinct()
    val orderForUser = recommendForUser(2294549, "sm_code", ordermodel, orderindex, order, orderindexDF, 2)
    orderForUser.show()

    // Release the local Spark context (the original leaked it).
    spark.stop()
  }

  /**
   * Recommends the top items for a single user, excluding items the user
   * has already interacted with.
   *
   * @param phone_no           target user id
   * @param unionx             name of the item-id column (join key)
   * @param model              trained ALS model
   * @param goods              DataFrame with the full distinct set of item ids
   * @param indexDF            user-item interaction DataFrame (has phone_no + unionx)
   * @param colDF              item-id to item-name lookup DataFrame
   * @param numRecommendations number of rows to return
   * @return top-N candidate items ordered by predicted score descending
   */
  def recommendForUser(phone_no: Int, unionx: String, model: ALSModel, goods: DataFrame, indexDF: DataFrame, colDF: DataFrame, numRecommendations: Int): DataFrame = {
    // Items the user has already touched.
    val used = indexDF.where(col("phone_no") === lit(phone_no)).select(unionx).distinct()
    // Candidate set: all items minus the used ones, tagged with the target user.
    val candidates = goods.except(used).withColumn("phone_no", lit(phone_no)).distinct()
    // Score candidates; unseen (cold-start) items predict NaN, which we treat as 0.
    val scored = model.transform(candidates).na.fill(0)
    scored.join(colDF, unionx)
      .orderBy(col("prediction").desc)
      .limit(numRecommendations)
  }

}
