上一篇文章,主要简单介绍了一些推荐相关的算法和ALS的实现,本篇将通过代码介绍spark2 ALS的使用,语言是scala。下面先上一部分maven:
<properties>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
<project.reporting.outputEncoding>UTF-8</project.reporting.outputEncoding>
<spark.version>2.2.0</spark.version>
<scala.version>2.11</scala.version>
</properties>
<!-- scala -->
<dependency>
<groupId>org.scala-lang</groupId>
<artifactId>scala-library</artifactId>
<version>2.11.11</version>
</dependency>
<dependency>
<groupId>org.scala-lang.modules</groupId>
<artifactId>scala-xml_2.11</artifactId>
<version>1.0.4</version>
</dependency>
<!-- spark -->
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-core_${scala.version}</artifactId>
<version>${spark.version}</version>
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-streaming_${scala.version}</artifactId>
<version>${spark.version}</version>
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-sql_${scala.version}</artifactId>
<version>${spark.version}</version>
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-hive_${scala.version}</artifactId>
<version>${spark.version}</version>
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-mllib_${scala.version}</artifactId>
<version>${spark.version}</version>
</dependency>
<!--mysql-->
<dependency>
<groupId>mysql</groupId>
<artifactId>mysql-connector-java</artifactId>
<version>5.0.5</version>
</dependency>
<dependency>
<groupId>com.fasterxml.jackson.core</groupId>
<artifactId>jackson-databind</artifactId>
<version>2.6.5</version>
</dependency>
谈到spark2首先要介绍两个东西:
一、SparkSession:
The entry point to programming Spark with the Dataset and DataFrame API. 使用Dataset和DataFrame API的spark程序的切入点,它提供了一个统一的入口。
在机器学习模块中,DataFrame 目前正渐渐取代RDD的API。RDD的API进入维护模式。SparkSession实质上是SQLContext和HiveContext的组合(未来可能还会
加上StreamingContext),所以在SQLContext和HiveContext上可用的API在SparkSession上同样是可以使用的。SparkSession内部封装了sparkContext,所以计
算实际上是由sparkContext完成的。
二、DataFrame:
DataFrame是一个基于列(column)的数据集,在形式上可以看成关系型数据库的一张表。直接这么说可能有点抽象,下面直接上图:
上图就是一个DataFrame,它有三列,是不是感觉很像我们使用的数据库的表。DataFrame的详细操作可以参考相关API。
下面介绍使用ALS的代码:
一、创建SparkSession:
val spark = SparkSession.builder
.master("local[4]")
.appName("test1")
.getOrCreate();
上面代码创建了一个SparkSession,它的名字是test1,master中local[4]创建了本地多线程,这里模拟使用4个内核,直接写local表示本地单线程。
二、从数据库中读取数据:
val mRecord = spark.read
.format("jdbc")
.option("url", "jdbc:mysql://192.168.1.154:3306/recommend?user=root&password=root")
.option("dbtable", "(select vip_id,store_id,times from recommend_consumption_times) as record") //数据
.option("driver","com.mysql.jdbc.Driver")
.load();
mRecord.show();
表示顾客光顾店铺的次数。最后mRecord.show()的结果就是上图的那个DataFrame。
mRecord.printSchema()
查看一下三个字段的数据类型,ALS支持的数据类型是 int int double,user和item的id都要是int,评分是double。
三、使用ALS推荐:
val als = new ALS()
.setMaxIter(5)
.setRegParam(0.1)
.setUserCol("vip_id")
.setItemCol("store_id")
.setRatingCol("times");
val model = als.fit(mRecord);
val userRecs = model.recommendForAllUsers(10);
userRecs.show(false);
首先new一个ALS,maxIter:计算时的最大迭代次数,默认10次。regParam正则化参数,默认1。还有其他参数可以参见相关API。然后,根据我们平分
记录训练一个模型model,这个model就可以用来推荐了,这里为每个用户推荐10个商品(评分最高的十个)。结果如下图:
其中vip_id是我们输入的用户,recommendations是推荐结果,推荐结果中前一个表示推荐的品牌id,后面是对这个品牌模拟的评分。
四、提取推荐的结果:
val df = userRecs.withColumn("recommendations",userRecs("recommendations").getField("store_id").cast(StringType)).withColumnRenamed("recommendations","store_ids");
df.show(false);
对上面的推荐结果进行处理,将recommendations中的结果只保留店铺id,去掉评分,然后把列的名字改为store_ids。结果如下:
五、将结果导入到数据库中:
val prop = new Properties();
prop.setProperty("user","root");
prop.setProperty("password","root");
prop.setProperty("driver","com.mysql.jdbc.Driver");
//写入数据库
df.write.mode(SaveMode.Overwrite).jdbc("jdbc:mysql://192.168.1.154:3306/recommend", "recommend_result",prop);
prop配置了数据库连接的属性,然后将数据写入recommend_result表中。其中SaveMode.Overwrite表示每次重新写入,如果希望每次把推荐结果追加到
后面,可以使用SaveMode.Append。
至此,我们就是用spark2的ALS算法完成了一个简单的推荐。