//查看下给定列的一些基本信息,例如数量、平均值、最大值、最小值、中位数这些基本描述
rating.describe("userId","movieId","rating").show
}
}
查看数据的平均值 中位数 最大值,最小值,对数据有个最基本的认识
将数据分为三部分:训练集、验证集、测试集
val splits = rating.randomSplit(Array(0.6,0.2,0.2),seed=1234)// 1234为随机种子,为了保证每次的验证结果相同
然后就是构造als模型,利用循环优化参数
代码为:
// Grid search: train one ALS model per (rank, lambda, numIter) combination
// and remember the one with the lowest validation RMSE.
for(rank <- ranks;lambda <- lambdas; numIter <- numIters) {
// Configure an ALS estimator for this hyper-parameter combination.
val als = new ALS()
.setMaxIter(numIter)
.setRegParam(lambda)
.setRank(rank)
.setNonnegative(true)
.setUserCol("userId")
.setItemCol("movieId")
.setRatingCol("rating")
val model: ALSModel = als.fit(training)
// Score this model on the held-out validation set.
val validationRmse = computeRmse(model,validation,numValidation)
println("RMSE(validation) =" +validationRmse +" for the model trained with rank = "
+ rank + ",lambda =" +lambda +",and numIter =" + numIter + ".")
// Keep the model (and its parameters) if it beats the current best.
if(validationRmse < bestValidationRmse) {
bestLambda = lambda
bestModel =Some(model)
bestNumIter = numIter
bestRank = rank
bestValidationRmse = validationRmse
}
}
最终利用训练好的最好的模型来运用测试集进行测试:
//用最佳模型预测测试集的评分,并计算和实际评分之间的均方根误差
val testRmse = computeRmse(bestModel.get,test,numTest)
println("The best model was trained with rank =" + bestRank +
" and lambda = " +bestLambda
+"and best numIter = " + bestNumIter + ", and its RMSE on the best set is "
+ testRmse + ".")
最终的计算结果为:
最终的整体代码为:
package com.huawei.sparkml.rs
import org.apache.spark.SparkConf
import org.apache.spark.ml.recommendation.{ALS, ALSModel}
import org.apache.spark.sql.{DataFrame, SparkSession}
// One MovieLens rating record: user id, movie id, score, and epoch timestamp.
case class Rating(userId:Int,movieId:Int,rating:Float,timestamp:Long)
object RecommandSystem {
  def main(args: Array[String]): Unit = {
    // Parse one tab-separated "userId movieId rating timestamp" line of u.data.
    def parseRating(str: String): Rating = {
      val fields = str.split("\t")
      assert(fields.size == 4)
      Rating(fields(0).toInt, fields(1).toInt, fields(2).toFloat, fields(3).toLong)
    }

    val conf = new SparkConf().setMaster("local").setAppName("RecommandSystem")
    // BUG FIX: the original read `config(conf)getOrCreate()` — the missing '.'
    // before getOrCreate() does not compile.
    val spark = SparkSession.builder().config(conf).getOrCreate()
    spark.sparkContext.setLogLevel("WARN")
    import spark.implicits._ // required to map the text file into a Dataset[Rating]

    val rating = spark.read.textFile("u.data_little").map(line => parseRating(line)).cache()

    // Split the data 6:2:2 into training / validation / test sets.
    // (The original comment claimed 8:2 — the code has always split three ways.)
    // The fixed seed makes every run produce the same split.
    val splits = rating.randomSplit(Array(0.6, 0.2, 0.2), seed = 1234)
    val training = splits(0).cache()
    val validation = splits(1).toDF().cache()
    val test = splits(2).toDF().cache()

    // Row counts of each split (also forces the cached datasets to materialize).
    val numTraining = training.count()
    val numValidation = validation.count()
    val numTest = test.count()

    // Hyper-parameter grid for the ALS model.
    val ranks = List(10, 20)
    val lambdas = List(0.01, 0.1)
    val numIters = List(5, 10)
    var bestModel: Option[ALSModel] = None
    var bestValidationRmse = Double.MaxValue
    var bestRank = 0
    var bestLambda = 1.0
    var bestNumIter = 1

    /**
     * Root-mean-square error of `model` on `data`.
     *
     * ALS emits NaN predictions for users/items never seen during training
     * (cold start); those rows are dropped before scoring.
     *
     * @param model trained ALS model
     * @param data  DataFrame containing "rating"; transform() adds "prediction"
     * @param n     total row count of `data`; retained for interface
     *              compatibility but no longer used as the denominator (see below)
     * @return RMSE over the rows that received a real prediction,
     *         or Double.NaN when no row could be scored
     */
    def computeRmse(model: ALSModel, data: DataFrame, n: Long): Double = {
      val predictions = model.transform(data)
      // Drop cold-start rows whose prediction is NaN.
      val valid = predictions.na.drop()
      // BUG FIX: the original divided the squared-error sum by `n` (the size of
      // the FULL set) even though NaN rows were excluded, underestimating the
      // RMSE. Divide by the number of rows actually scored instead. The
      // original also self-joined predictions on (userId, movieId) merely to
      // pair rating with prediction — both values already live on the same row.
      val numScored = valid.count()
      if (numScored == 0) {
        Double.NaN
      } else {
        val sumSquaredError = valid.rdd.map { row =>
          val diff = row.getAs[Float]("rating").toDouble - row.getAs[Float]("prediction").toDouble
          diff * diff
        }.sum()
        math.sqrt(sumSquaredError / numScored)
      }
    }

    // Grid search: train one model per (rank, lambda, numIter) combination and
    // keep the one with the lowest validation RMSE.
    for (rank <- ranks; lambda <- lambdas; numIter <- numIters) {
      val als = new ALS()
        .setMaxIter(numIter)
        .setRegParam(lambda)
        .setRank(rank)
        .setNonnegative(true)
        .setUserCol("userId")
        .setItemCol("movieId")
        .setRatingCol("rating")
      val model: ALSModel = als.fit(training)
      val validationRmse = computeRmse(model, validation, numValidation)
      println("RMSE(validation) =" + validationRmse + " for the model trained with rank = "
        + rank + ",lambda =" + lambda + ",and numIter =" + numIter + ".")
      if (validationRmse < bestValidationRmse) {
        bestLambda = lambda
        bestModel = Some(model)
        bestNumIter = numIter
        bestRank = rank
        bestValidationRmse = validationRmse
      }
    }

    // Score the winning model on the held-out test set. The grid above is
    // non-empty, so bestModel is always set; fail loudly rather than via
    // Option.get's bare NoSuchElementException if that ever changes.
    val best = bestModel.getOrElse(sys.error("no ALS model was trained"))
    val testRmse = computeRmse(best, test, numTest)
    // BUG FIX: the original message said "on the best set"; it is the test set.
    println("The best model was trained with rank =" + bestRank +
      " and lambda = " + bestLambda
      + "and best numIter = " + bestNumIter + ", and its RMSE on the test set is "
      + testRmse + ".")
  }
}
pom文件为:
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<groupId>org.example</groupId>
<artifactId>ml</artifactId>
<version>1.0-SNAPSHOT</version>
<properties>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
<maven.compiler.source>1.8</maven.compiler.source>
<maven.compiler.target>1.8</maven.compiler.target>
<!-- Scala minor version must match the _2.12 suffix of the Spark artifacts below -->
<scala.version>2.12.12</scala.version>
</properties>
<dependencies>
<dependency>
<groupId>org.scala-lang</groupId>
<artifactId>scala-library</artifactId>
<version>${scala.version}</version>
</dependency>
<dependency>
<groupId>junit</groupId>
<artifactId>junit</artifactId>
<version>4.11</version>
<scope>test</scope>
</dependency>
<dependency> <!-- Spark dependency -->
<groupId>org.apache.spark</groupId>
<artifactId>spark-sql_2.12</artifactId>
<version>2.4.3</version>
</dependency>
<!-- https://mvnrepository.com/artifact/org.apache.spark/spark-mllib -->
<!-- provides org.apache.spark.ml.recommendation.ALS used by RecommandSystem -->
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-mllib_2.12</artifactId>
<version>2.4.3</version>
<!-- <scope>provided</scope>-->
</dependency>
</dependencies>
<build>
<pluginManagement><!-- lock down plugins versions to avoid using Maven defaults (may be moved to parent pom) -->
<plugins>
<!-- clean lifecycle, see https://maven.apache.org/ref/current/maven-core/lifecycles.html#clean_Lifecycle -->
<plugin>
<artifactId>maven-clean-plugin</artifactId>
<version>3.1.0</version>
</plugin>
<!-- default lifecycle, jar packaging: see https://maven.apache.org/ref/current/maven-core/default-bindings.html#Plugin_bindings_for_jar_packaging -->
<plugin>
<artifactId>maven-resources-plugin</artifactId>
<version>3.0.2</version>
</plugin>
<plugin>
<artifactId>maven-compiler-plugin</artifactId>
<version>3.8.0</version>
</plugin>
<plugin>
<artifactId>maven-surefire-plugin</artifactId>
<version>2.22.1</version>
</plugin>
<plugin>
<artifactId>maven-jar-plugin</artifactId>
<version>3.0.2</version>
</plugin>
<plugin>
<artifactId>maven-install-plugin</artifactId>
<version>2.5.2</version>
</plugin>
<plugin>
<artifactId>maven-deploy-plugin</artifactId>
<version>2.8.2</version>
</plugin>
<!-- site lifecycle, see https://maven.apache.org/ref/current/maven-core/lifecycles.html#site_Lifecycle -->
<plugin>
<artifactId>maven-site-plugin</artifactId>
<version>3.7.1</version>
</plugin>
<plugin>
<artifactId>maven-project-info-reports-plugin</artifactId>
<version>3.0.0</version>
</plugin>
</plugins>
</pluginManagement>
</build>
</project>
对应的数据文件格式为:
196 242 3 881250949
186 302 3 891717742
22 377 1 878887116
244 51 2 880606923
166 346 1 886397596
298 474 4 884182806
115 265 2 881171488