package com.svg.ALS
import org.apache.spark.ml.recommendation.{ALS, ALSModel}
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.expressions.{UserDefinedFunction, Window}
import org.apache.spark.sql.functions._
import scala.sys.exit
object ALS_ {

  /**
   * Builds two ALS recommendation models from broadcast/cable-TV user data in Hive:
   *
   *  1. station model — counts how often each user watched each `station_name`
   *     in `user_profile.media_index_process`, maps the counts onto a 1-5 score,
   *     and trains ALS on (phone_no, station_id, score);
   *  2. service-mode model — counts `sm_name` occurrences per user in
   *     `user_profile.order_index_process`, scores them the same way, and trains
   *     ALS on (phone_no, sm_code, score).
   *
   * For a sample user it then prints top-N recommendations of unseen items.
   */
  def main(args: Array[String]): Unit = {
    System.setProperty("HADOOP_USER_NAME", "hadoop")
    val warehouse = "hdfs://LW-master:9820/user/hive/warehouse"
    val spark: SparkSession = SparkSession.builder()
      .appName("twa")
      .master("local[2]")
      .enableHiveSupport()
      .config("spark.sql.warehouse.dir", warehouse)
      .config("hive.metastore.uris", "thrift://LW-master:9083")
      .config("dfs.client.use.datanode.hostname", "true")
      .getOrCreate()
    spark.sparkContext.setLogLevel("WARN")

    // (phone_no, sm_code) per user; joined into both rating tables below.
    val usermsg = spark.table("user_profile.usermsg_process").select("phone_no", "sm_code")

    // ---------- station_name ratings ----------
    // Per-user view counts for each station.
    val stationCounts = spark.read.table("user_profile.media_index_process")
      .filter("station_name!='NULL'")
      .groupBy("phone_no", "station_name")
      .agg(count("*").alias("station_count"))

    // One Spark action for both extremes (the naive version collects twice).
    val stationRow = stationCounts.agg(max("station_count"), min("station_count")).head()
    val stationScore = scoreUdf(bucketWidth(stationRow.getLong(0), stationRow.getLong(1)))

    // ALS needs an integer item column, so assign each station a dense id.
    val station = stationCounts.select("station_name").distinct()
      .withColumn("station_id", row_number().over(Window.orderBy(col("station_name"))))

    val index = stationCounts
      .join(station, "station_name")
      .join(usermsg, "phone_no")
      .withColumn("score", stationScore(col("station_count")))
      .withColumn("phone_no", col("phone_no").cast("int"))
      .select("phone_no", "station_name", "station_id", "station_count", "score")
    index.show()

    // ---------- sm_name ratings ----------
    // Per-user order counts for each service mode.
    val smCounts = spark.read.table("user_profile.order_index_process")
      .filter("sm_name!='NULL'")
      .groupBy("phone_no", "sm_name")
      .agg(count("*").alias("sm_count"))

    val smRow = smCounts.agg(max("sm_count"), min("sm_count")).head()
    val smScore = scoreUdf(bucketWidth(smRow.getLong(0), smRow.getLong(1)))

    // Map the service-mode display names onto stable integer codes for ALS.
    val re_sm_code = udf((smName: String) =>
      if (smName.contains("甜果电视")) 1
      else if (smName.contains("数字电视")) 2
      else if (smName.contains("互动电视")) 3
      else 4)

    val order = smCounts
      .withColumn("score", smScore(col("sm_count")))
      .join(usermsg, "phone_no")
      // overwrite the sm_code brought in by the usermsg join with the derived code
      .withColumn("sm_code", re_sm_code(col("sm_name")))
      .withColumn("phone_no", col("phone_no").cast("int"))
      .select("phone_no", "sm_name", "sm_code", "sm_count", "score")
      .orderBy(col("score").desc)
    order.show()

    // ---------- ALS on (phone_no, station_id, score) ----------
    val indexals = new ALS()
      .setMaxIter(10)
      .setRegParam(0.01)
      .setUserCol("phone_no")
      .setItemCol("station_id")
      .setRatingCol("score")
    val Array(indexTrain, indexTest) =
      index.selectExpr("phone_no", "station_id", "score").randomSplit(Array(0.8, 0.2))
    val model = indexals.fit(indexTrain)

    // Recommend up to 6 stations this sample user has not watched yet.
    val stationIds = index.select("station_id").distinct()
    recommendForUser(2294549, "station_id", model, stationIds, index, station, 6).show()

    // Held-out predictions next to the true scores.
    // NOTE(review): coldStartStrategy is the default, so unseen users/items
    // predict NaN here — consider .setColdStartStrategy("drop") for evaluation.
    model.transform(indexTest)
      .join(station, "station_id")
      .select("phone_no", "station_name", "score", "prediction")
      .show()

    // ---------- ALS on (phone_no, sm_code, score) ----------
    val orderals = new ALS()
      .setMaxIter(10)
      .setRegParam(0.01)
      .setUserCol("phone_no")
      .setItemCol("sm_code")
      .setRatingCol("score")
    val Array(orderTrain, orderTest) =
      order.selectExpr("phone_no", "sm_code", "score").randomSplit(Array(0.8, 0.2))
    val ordermodel = orderals.fit(orderTrain)

    // Recommend up to 2 unseen service modes to the same sample user.
    val smNames = order.select("sm_name", "sm_code").distinct()
    val smCodes = order.select("sm_code").distinct()
    recommendForUser(2294549, "sm_code", ordermodel, smCodes, order, smNames, 2).show()

    // Release the session's resources (the original job never stopped it).
    spark.stop()
  }

  /**
   * Width of one of the five score buckets for a count spread.
   * Guarded to be at least 1: with integer division, a spread below 5 would
   * otherwise yield width 0 and every non-zero count would score 5.
   */
  private def bucketWidth(maxCount: Long, minCount: Long): Long =
    math.max(1L, (maxCount - minCount) / 5)

  /** UDF mapping a raw occurrence count onto a 1-5 score given the bucket width. */
  private def scoreUdf(bucket: Long): UserDefinedFunction = udf((n: Int) => n match {
    case c if c > bucket * 4 => 5
    case c if c > bucket * 3 => 4
    case c if c > bucket * 2 => 3
    case c if c > bucket     => 2
    case _                   => 1
  })

  /**
   * Recommends up to `numRecommendations` items the user has not interacted with.
   *
   * @param phone_no           user id to recommend for
   * @param unionx             name of the item-id column the model predicts on
   * @param model              fitted ALS model
   * @param goods              all distinct item ids
   * @param indexDF            (phone_no, item) interaction table
   * @param colDF              item-id to item-name lookup table
   * @param numRecommendations maximum number of rows returned
   * @return top predictions joined with item names, highest prediction first
   */
  def recommendForUser(phone_no: Int, unionx: String, model: ALSModel, goods: DataFrame,
                       indexDF: DataFrame, colDF: DataFrame, numRecommendations: Int): DataFrame = {
    // Items the user has already interacted with.
    val seen = indexDF.where(col("phone_no") === lit(phone_no)).select(unionx).distinct()
    // Candidates = all items minus seen ones, tagged with the user id.
    val candidates = goods.except(seen).withColumn("phone_no", lit(phone_no)).distinct()
    // Predict; NaN (cold-start) is replaced by 0 so the ordering stays total.
    model.transform(candidates).na.fill(0)
      .join(colDF, unionx)
      .orderBy(col("prediction").desc)
      .limit(numRecommendations)
  }
}
// 基于广电数据ALS推荐算法 (ALS recommendation algorithm on broadcast/cable-TV data)
// 于 2024-05-10 15:20:29 首次发布 (first published 2024-05-10 15:20:29)