spark ml之推荐系统实现


    //查看下给定列个值得一个基本信息,例如数量 平均值 最大值 最小值 中位数这些基本描述
    rating.describe("userId","movieId","rating").show
  }
}

查看数据的平均值 中位数 最大值,最小值,对数据有个最基本的认识
在这里插入图片描述

将数据分文三部分:训练集 验证集 测试集

 val splits = rating.randomSplit(Array(0.6,0.2,0.2),seed=1234)// 1234为随机种子,为了保证每次的验证结果相同


然后就是构造als模型,利用循环优化参数
代码为:

 for(rank <- ranks;lambda <- lambdas; numIter <- numIters) {
      val als = new ALS()
        .setMaxIter(numIter)
        .setRegParam(lambda)
        .setRank(rank)
        .setNonnegative(true)
        .setUserCol("userId")
        .setItemCol("movieId")
        .setRatingCol("rating")
      val model: ALSModel = als.fit(training)
      val validationRmse = computeRmse(model,validation,numValidation)
      println("RMSE(validation) =" +validationRmse +" for the model trained with rank = "
      + rank + ",lambda =" +lambda +",and numIter =" + numIter + ".")
      if(validationRmse < bestValidationRmse) {
        bestLambda = lambda
        bestModel =Some(model)
        bestNumIter = numIter
        bestRank = rank
        bestValidationRmse = validationRmse
      }
    }

最终利用训练好的最好的模型来运用测试集进行测试:

  //用最佳模型预测测试集的评分,并计算和实际评分之间的均方根误差
    val testRmse = computeRmse(bestModel.get,test,numTest)
    println("The best model was trained with rank =" + bestRank +
    " and lambda = " +bestLambda
    +"and best numIter  = " + bestNumIter + ", and its RMSE on the best set is "
    + testRmse + ".")

最终的计算结果为:

在这里插入图片描述

最终的整体代码为:

package com.huawei.sparkml.rs

import org.apache.spark.SparkConf
import org.apache.spark.ml.recommendation.{ALS, ALSModel}
import org.apache.spark.sql.{DataFrame, SparkSession}

case class Rating(userId:Int,movieId:Int,rating:Float,timestamp:Long)
object RecommandSystem {
  def main(args: Array[String]): Unit = {
    def parseRating(str:String):Rating = {
      val fields = str.split("\t")
      assert(fields.size == 4)
        Rating(fields(0).toInt,fields(1).toInt,fields(2).toFloat,fields(3).toLong)
    }
    val conf = new SparkConf().setMaster("local").setAppName("RecommandSystem")
    val spark = SparkSession.builder().config(conf)getOrCreate()
    spark.sparkContext.setLogLevel("WARN")
    import spark.implicits._  //不加这一句,就没法转化成dataset对象,就会报错
    val rating = spark.read.textFile("u.data_little").map(line => parseRating(line)).cache()
//    rating.take(10).foreach(line =>println(line))
//    rating.show(4)
    //查看下给定列个值得一个基本信息,例如数量 平均值 最大值 最小值 中位数这些基本描述
//    rating.describe("userId","movieId","rating").show

    //构建模型
    //将数据按照8:2分为训练集和测试集
    val splits = rating.randomSplit(Array(0.6,0.2,0.2),seed=1234)
    //查看各个几个的数据量
    val training = splits(0).cache()
    val validation = splits(1).toDF().cache()
    val test = splits(2).toDF().cache()
    //计算各个集合总数
    val numTraining = training.count()
    val numValidation = validation.count()
    val numTest = test.count()
    //利用交替最小二乘法计算
    //优化并构建模型
    //初始化参数网络
    val ranks = List(10,20)
    val lambdas = List(0.01,0.1)
    val numIters = List(5,10)
    var bestModel:Option[ALSModel] = None
    var bestValidationRmse = Double.MaxValue
    var bestRank =0
    var bestLambda =1.0
    var bestNumIter =1

    //计算rmse
    def computeRmse(model:ALSModel,data:DataFrame,n:Long):Double={
      /**
       * +++++++++
       * +------+-------+------+---------+----------+
       * |userId|movieId|rating|timestamp|prediction|
       * +------+-------+------+---------+----------+
       * |224   |29     |3.0   |888104457|NaN       |
       * |196   |242    |3.0   |881250949|NaN       |
       * +------+-------+------+---------+----------+
       *
       * --------2
       * =====================
       * root
       * |-- userId: integer (nullable = false)
       * |-- movieId: integer (nullable = false)
       * |-- rating: float (nullable = false)
       * |-- timestamp: long (nullable = false)
       * |-- prediction: float (nullable = false)
       *
       *
       *
       * */


      val predictions = model.transform(data)
//      println("+++++++++")
//      predictions.show(false)
//      println("--------"+predictions.count())
//      println("=====================")
//      predictions.printSchema()
      //predictions.na.drop() 删除那些没有推荐项的数据
      /*
      join 完的数据形式:
      * ((224,29),(3.0,NaN))
        ((196,242),(3.0,NaN))
      * */
      //x(2) 为rating x(4) 为prediction  rating 为真rating值 prediction为预测值
      //这里需要大家对数据样式进行分析
      val p1 = predictions.na.drop().rdd.map{
        x => ((x(0),x(1)),x(2))
      }.join(predictions.rdd.map{x=>((x(0),x(1)),x(4))}).values //p1.为 rating |prediction值
//      println("p1.count="+p1.count())
//      p1.take(20).foreach(println(_))
      math.sqrt(p1.map(x =>(x._1.toString.toDouble - x._2.toString.toDouble) *
        (x._1.toString.toDouble - x._2.toString.toDouble)).reduce(_+_)/n)
    }

    for(rank <- ranks;lambda <- lambdas; numIter <- numIters) {
      val als = new ALS()
        .setMaxIter(numIter)
        .setRegParam(lambda)
        .setRank(rank)
        .setNonnegative(true)
        .setUserCol("userId")
        .setItemCol("movieId")
        .setRatingCol("rating")
      val model: ALSModel = als.fit(training)
      val validationRmse = computeRmse(model,validation,numValidation)
      println("RMSE(validation) =" +validationRmse +" for the model trained with rank = "
      + rank + ",lambda =" +lambda +",and numIter =" + numIter + ".")
      if(validationRmse < bestValidationRmse) {
        bestLambda = lambda
        bestModel =Some(model)
        bestNumIter = numIter
        bestRank = rank
        bestValidationRmse = validationRmse
      }
    }
    //用最佳模型预测测试集的评分,并计算和实际评分之间的均方根误差
    val testRmse = computeRmse(bestModel.get,test,numTest)
    println("The best model was trained with rank =" + bestRank +
    " and lambda = " +bestLambda
    +"and best numIter  = " + bestNumIter + ", and its RMSE on the best set is "
    + testRmse + ".")

  }
}

pom文件为:

<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0</modelVersion>

    <groupId>org.example</groupId>
    <artifactId>ml</artifactId>
    <version>1.0-SNAPSHOT</version>
    <properties>
        <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
        <maven.compiler.source>1.8</maven.compiler.source>
        <maven.compiler.target>1.8</maven.compiler.target>
        <scala.version>2.12.12</scala.version>
    </properties>

    <dependencies>
        <dependency>
            <groupId>org.scala-lang</groupId>
            <artifactId>scala-library</artifactId>
            <version>${scala.version}</version>
        </dependency>
        <dependency>
            <groupId>junit</groupId>
            <artifactId>junit</artifactId>
            <version>4.11</version>
            <scope>test</scope>
        </dependency>
        <dependency> <!-- Spark dependency -->
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-sql_2.12</artifactId>
            <version>2.4.3</version>
        </dependency>
        <!-- https://mvnrepository.com/artifact/org.apache.spark/spark-mllib -->
        <dependency>
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-mllib_2.12</artifactId>
            <version>2.4.3</version>
            <!--      <scope>provided</scope>-->
        </dependency>

    </dependencies>

    <build>
        <pluginManagement><!-- lock down plugins versions to avoid using Maven defaults (may be moved to parent pom) -->
            <plugins>
                <!-- clean lifecycle, see https://maven.apache.org/ref/current/maven-core/lifecycles.html#clean_Lifecycle -->
                <plugin>
                    <artifactId>maven-clean-plugin</artifactId>
                    <version>3.1.0</version>
                </plugin>
                <!-- default lifecycle, jar packaging: see https://maven.apache.org/ref/current/maven-core/default-bindings.html#Plugin_bindings_for_jar_packaging -->
                <plugin>
                    <artifactId>maven-resources-plugin</artifactId>
                    <version>3.0.2</version>
                </plugin>
                <plugin>
                    <artifactId>maven-compiler-plugin</artifactId>
                    <version>3.8.0</version>
                </plugin>
                <plugin>
                    <artifactId>maven-surefire-plugin</artifactId>
                    <version>2.22.1</version>
                </plugin>
                <plugin>
                    <artifactId>maven-jar-plugin</artifactId>
                    <version>3.0.2</version>
                </plugin>
                <plugin>
                    <artifactId>maven-install-plugin</artifactId>
                    <version>2.5.2</version>
                </plugin>
                <plugin>
                    <artifactId>maven-deploy-plugin</artifactId>
                    <version>2.8.2</version>
                </plugin>
                <!-- site lifecycle, see https://maven.apache.org/ref/current/maven-core/lifecycles.html#site_Lifecycle -->
                <plugin>
                    <artifactId>maven-site-plugin</artifactId>
                    <version>3.7.1</version>
                </plugin>
                <plugin>
                    <artifactId>maven-project-info-reports-plugin</artifactId>
                    <version>3.0.0</version>
                </plugin>
            </plugins>
        </pluginManagement>
    </build>

</project>

对应的数据文件格式为:

196 242 3 881250949
186 302 3 891717742
22 377 1 878887116
244 51 2 880606923
166 346 1 886397596
298 474 4 884182806
115 265 2 881171488

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值