SparkML 实现 ALS 算法

引入依赖

<dependency>
    <groupId>org.apache.spark</groupId>
    <artifactId>spark-mllib_2.12</artifactId>
    <version>2.4.4</version>
    <exclusions>
        <exclusion>
            <groupId>com.google.guava</groupId>
            <artifactId>guava</artifactId>
        </exclusion>
    </exclusions>
</dependency>
<dependency>
    <groupId>com.google.guava</groupId>
    <artifactId>guava</artifactId>
    <version>14.0.1</version>
</dependency>

数据准备

门店数据
  • 通过 dml.sql 导入了 400 条数据;
行为数据
  • 保存在文件 behavior.csv 中,总共 3 列,第一列 userId,第二列 shopId,第三列用户对这个门店的钟爱度打分;
  • behavior.csv 中大概有 2 万多条数据;

离线 ALS 召回模型的训练

离线 ALS 召回模型的训练 | 过程
  • 读行为数据 behavior.csv 到内存中;
  • 转换数据结构:JavaRDD<String> -> JavaRDD<Rating> -> Dataset<Row>;
  • 按 8-2 分,将行为数据集分成 2 份,一份训练用,一份测试用;
  • 设置 ALS 模型的参数:.setMaxIter(10).setRank(5).setRegParam(0.01)
  • 生成模型;
  • 生成模型测评器;
  • 用测试行为数据,测试生成的模型,得到 rmse 得分;
  • 生成的模型可以保存在磁盘;
模型生成的结果
  • alsmodel
    • itemFactor - 存储门店训练出来的特征值;
    • metadata
    • userFactors - 存储用户训练出来的特征值,二进制的;
离线 ALS 召回模型的训练 | 代码
package tech.lixinlei.dianping.recommand;

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.ml.evaluation.RegressionEvaluator;
import org.apache.spark.ml.recommendation.ALS;
import org.apache.spark.ml.recommendation.ALSModel;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;

import java.io.IOException;
import java.io.Serializable;


/**
 * ALS 召回算法的训练
 * 实现 Serializable 是因为,Spark 的程序可以运行在不同的机器上;
 */
public class AlsRecallTrain implements Serializable {

    public static void main(String[] args) throws IOException {

        //初始化spark运行环境
        SparkSession spark = SparkSession.builder().master("local").appName("DianpingApp").getOrCreate();

        JavaRDD<String> csvFile = spark.read().textFile("file:///home/lixinlei/project/gitee/dianping/src/main/resources/behavior.csv").toJavaRDD();

        JavaRDD<Rating> ratingJavaRDD = csvFile.map(new Function<String, Rating>() {
            /**
             * 将 behavior.csv 中的一行,从 String 转成 Rating;
             * @param v1 behavior.csv 中数据的一行
             * @return
             * @throws Exception
             */
            @Override
            public Rating call(String v1) throws Exception {
                return Rating.parseRating(v1);
            }
        });

        // Dataset 可以理解为 MySQL 中的一张表,row 中 column 的定义遵从 Rating 的定义;
        Dataset<Row> rating = spark.createDataFrame(ratingJavaRDD, Rating.class);

        // 将所有的 rating 数据分成 8-2 分,80% 的数据用来做训练,20% 的训练用来做测试
        Dataset<Row>[] splits = rating.randomSplit(new double[]{0.8, 0.2});
        Dataset<Row> trainingData = splits[0];
        Dataset<Row> testingData = splits[1];

        // .setMaxIter(10) 设置最大拟合次数
        // .setRank(5) 分解矩阵后 feature 的数量
        // .setRegParam(0.01) 正则化系数,增大正则化的值,可以防止过拟合的情况
        // 过拟合:指得是模型训练出来的内容,过分的逼近真实数据,导致一旦真实数据出现一些误差,预测的结果反而不尽如人意;
        // 欠拟合:模型训练出来的内容,没有达到收敛于真是数据,使得预测结果的偏差距离真实结果太大;
        // 过拟合的解决方案:1)增大数据规模 2)减少 RANK,即特征的数量,使得模型预测的能力更加松散 3)增大正则化的系数
        // 欠拟合的解决方案:1)增加 RANK 2)减少正则化系数
        ALS als = new ALS().setMaxIter(10).setRank(5).setRegParam(0.01).
                setUserCol("userId").setItemCol("shopId").setRatingCol("rating");

        // 模型训练
        ALSModel alsModel = als.fit(trainingData);

        // 模型评测:测评的时候,用到了 testingData 中的 userId 和 shopId 字段的值,没有用 rating 字段的值,而且计算出了一个新字段,叫 prediction
        Dataset<Row> predictions = alsModel.transform(testingData);

        // rmse 均方根误差,预测值与真实值的偏差的平方除以观测次数(testingData 的条数),开个根号
        // rmse 的值越小,标识模型在测试数据集上的表现越好;
        RegressionEvaluator evaluator = new RegressionEvaluator().setMetricName("rmse")
                .setLabelCol("rating").setPredictionCol("prediction");
        double rmse = evaluator.evaluate(predictions);
        System.out.println("rmse = " + rmse);

        alsModel.save("file:///home/lixinlei/project/gitee/dianping/src/main/resources/alsmodel");
    }

    /**
     * 自定义数据结构,用来承接 behavior.csv 中的一行数据;
     */
    public static class Rating implements Serializable{

        private int userId;
        private int shopId;
        private int rating;

        /**
         * 将 hebavior.csv 中的一行数据,组装成 Rating 对象返回;
         * @param str behavior.csv 文件的一行输入
         * @return
         */
        public static Rating parseRating(String str){
            str = str.replace("\"","");
            String[] strArr = str.split(",");
            int userId = Integer.parseInt(strArr[0]);
            int shopId = Integer.parseInt(strArr[1]);
            int rating = Integer.parseInt(strArr[2]);
            return new Rating(userId,shopId,rating);
        }

        public Rating(int userId, int shopId, int rating) {
            this.userId = userId;
            this.shopId = shopId;
            this.rating = rating;
        }

        public int getUserId() {
            return userId;
        }

        public int getShopId() {
            return shopId;
        }

        public int getRating() {
            return rating;
        }
    }

}
  • 1
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 3
    评论
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值