创新实训(41)——在Springboot项目中使用Spark的ALS算法实现协同过滤推荐

前言

昨晚本来想把这部分的博客内容,完成的,结果只写到了设计,时间就不早了,今天把具体的实现,还有实现过程中所遇到的所有的问题写在这里。

引入依赖

这次我用了Spark2.x的java api,并且了解到spark底层是scala实现了,然后上层的api有scala版本和java版本,这里我使用了它提供的java的api,并且java底层调用的函数都是scala实现的,非常的方便,可以与java进行无缝的操作。
spark MLlib依赖:

  <dependency>
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-mllib_2.12</artifactId>
            <version>2.4.4</version>
            <exclusions>
                <exclusion>
                    <groupId>com.google.guava</groupId>
                    <artifactId>guava</artifactId>
                </exclusion>
            </exclusions>
</dependency>

使用SparkSession链接Spark集群

   //初始化spark初始环境
 SparkSession spark = SparkSession.builder().master("local").appName("BlogStormApp").getOrCreate();

我们连接的是本地的master节点,在连接时出现了缺少依赖的问题

在这里插入图片描述
经过查阅资料之后,添加了两个依赖

  <dependency>
            <groupId>com.google.guava</groupId>
            <artifactId>guava</artifactId>
            <version>14.0.1</version>
</dependency>
<dependency>
            <groupId>org.codehaus.janino</groupId>
            <artifactId>janino</artifactId>
            <version>3.0.8</version>
</dependency>

这连个依赖不是我们写的代码使用的,好像是Spark内部调用的方法使用的,缺少的话确实程序会运行失败的。

模型的训练

(1)使用spark的方法读取csv数据,然后序列化为javaRDD的格式

 /** 内部序列化的静态类
     * 数据模型   用户 分类  点击量
     */
    public  static class Rating implements Serializable{
        private  int userId;
        private  int itemId;
        private  int rating;
        public static Rating parseRating(String str){
//            str = str.replace("\"","");
            String[] strArr = str.split(",");
            int userId = Integer.parseInt(strArr[0]);
            int itemId = Integer.parseInt(strArr[1]);
            int rating = Integer.parseInt(strArr[2]);

            return new Rating(userId,itemId,rating);
        }

        public Rating()
        {

        }
        public Rating(int userId, int itemId, int rating) {
            this.userId = userId;
            this.itemId = itemId;
            this.rating = rating;
        }
        public int getUserId() {
            return userId;
        }

        public int getItemId() {
            return itemId;
        }

        public int getRating() {
            return rating;
        }

    }
 JavaRDD<String> csvFile = spark.read().textFile(DATA_DIR).toJavaRDD();

        JavaRDD<Rating> ratingJavaRDD = csvFile.map(new Function<String, Rating>() {
            @Override
            public Rating call(String s) {

                return Rating.parseRating(s);
            }
        });


(2)将JavaRDD的格式转化为DataFrame的格式(DataFrame中的列可以是存储的文本,特征向量,真实标签和预测的标签等),然后切分训练集和测试集,方便验证模型。
然后调整参数,调用方法进行模型的训练

        Dataset<Row> rating = spark.createDataFrame(ratingJavaRDD, Rating.class);

        //将所有的rating数据分成82分
        Dataset<Row>[] splits = rating.randomSplit(new double[]{0.8, 0.2});
        Dataset<Row> trainingData = splits[0];
        Dataset<Row> testingData = splits[1];
        //5个特征,
        //解决过拟合:1增大数据规模,减少rank,增大正则化系数
        //欠拟合:增加rank,减少正则化系数
        ALS als = new ALS().setMaxIter(10).setRank(5).setRegParam(0.01).setUserCol("userId")
                .setItemCol("itemId").setRatingCol("rating");
        ALSModel alsModel = als.fit(trainingData);

        //模型评测 将测试集使用模型做一次转化的预测
        Dataset<Row> predictions = alsModel.transform(testingData);

        //rmse 均方根误差,预测值与真实值的偏差的平方除以观测次数,开个根号
        RegressionEvaluator evaluator = new RegressionEvaluator().setMetricName("rmse")
                .setLabelCol("rating").setPredictionCol("prediction");
        double rmse = evaluator.evaluate(predictions);
        System.out.println("rmse" + rmse);

(3)训练结果以及参数的调整(使用均方根误差来衡量)
均方根误差是预测值与真实值偏差的平方与观测次数n比值的平方根,在实际测量中,观测次数n总是有限的,真值只能用最可信赖(最佳)值来代替。精密度。这正是标准误差在工程测量中广泛被采用的原因。因此,标准差是用来衡量一组数自身的离散程度,而均方根误差是用来衡量观测值同真值之间的偏差,它们的研究对象和研究目的不同,但是计算过程类似 。
参数调整的原则:
解决过拟合:1增大数据规模,减少rank,增大正则化系数
欠拟合:增加rank,减少正则化系数

   ALS als = new ALS().setMaxIter(10).setRank(5).setRegParam(0.01).setUserCol("userId")
                .setItemCol("itemId").setRatingCol("rating");

当rank为5 正则化系数为0.01时的结果:
在这里插入图片描述

   ALS als = new ALS().setMaxIter(10).setRank(7).setRegParam(0.01).setUserCol("userId")
                .setItemCol("itemId").setRatingCol("rating");

当rank为7 正则化系数为0.01时的结果(发现提升rank值会增大误差,所以需要减少rank值):
在这里插入图片描述

   ALS als = new ALS().setMaxIter(10).setRank(4).setRegParam(0.01).setUserCol("userId")
                .setItemCol("itemId").setRatingCol("rating");

当rank为4 正则化系数为0.01时的结果(减小rank值之后,发现误差变小):
在这里插入图片描述

ALS als = new ALS().setMaxIter(10).setRank(3).setRegParam(0.02).setUserCol("userId")
                .setItemCol("itemId").setRatingCol("rating");

当rank为3 正则化系数为0.02时的结果(提升正则化系数,误差变大了,所以降低一下正则化系数):
在这里插入图片描述

ALS als = new ALS().setMaxIter(10).setRank(4).setRegParam(0.005).setUserCol("userId")
                .setItemCol("itemId").setRatingCol("rating");

当rank为4 正则化系数为0.005时的结果(减少正则化系数 误差经一步提高 ):
在这里插入图片描述

ALS als = new ALS().setMaxIter(10).setRank(3).setRegParam(0.01).setUserCol("userId")
                .setItemCol("itemId").setRatingCol("rating");

当rank为3 正则化系数为0.01时的结果( rank为3时 误差比rank为4大,所以重新调整rank):
在这里插入图片描述

ALS als = new ALS().setMaxIter(10).setRank(4).setRegParam(0.009).setUserCol("userId")
                .setItemCol("itemId").setRatingCol("rating");

当rank为4 正则化系数为0.09时的结果( 重新调整rank为4 正则化系数为0.09,效果不错):
在这里插入图片描述

ALS als = new ALS().setMaxIter(10).setRank(4).setRegParam(0.07).setUserCol("userId")
                .setItemCol("itemId").setRatingCol("rating");

当rank为4 正则化系数为0.07时的结果( 正则化系数为0.07,误差变大):
在这里插入图片描述

ALS als = new ALS().setMaxIter(10).setRank(4).setRegParam(0.1).setUserCol("userId")
                .setItemCol("itemId").setRatingCol("rating");

当rank为4 正则化系数为0.1时的结果( 正则化系数为0.1,误差变大):
在这里插入图片描述

所以还是rank为4,且正则化系数为0.09比较合适
在这里插入图片描述
(4)模型的存储

/**
     * 删除文件夹下所有文件
     *
     * @param file
     */
    private static void deleteFile(File file) {
        //判断是否为文件,是,则删除
        if (file.isFile()) {

            file.delete();
        } else {//不为文件,则为文件夹
            //获取文件夹下所有文件相对路径
            String[] childFilePath = file.list();
            for (String path : childFilePath) {
                File childFile = new File(file.getAbsoluteFile() + "/" + path);
                //递归,对每个都进行判断
                deleteFile(childFile);
            }
            file.delete();
        }
    }
 //删除文件夹下所有文件
        File file = new File(MODEL_DIR);
        deleteFile(file);
        //缓存模型
        alsModel.save(MODEL_DIR);

读取模型进行预测

(1)模型的读取

 //初始化spark运行环境
        SparkSession spark = SparkSession.builder().master("local").appName("BlogStormApp").getOrCreate();
        //将模型加载到内存之中
        ALSModel alsModel = ALSModel.load(MODEL_DIR);

(2)进行预测

 JavaRDD<String> csvFile = spark.read().textFile(DATA_DIR).toJavaRDD();
        JavaRDD<Rating> ratingJavaRDD = csvFile.map(new Function<String, Rating>() {
            @Override
            public Rating call(String s) throws Exception {
                return Rating.parseRating(s);
            }
        });

        Dataset<Row> rating = spark.createDataFrame(ratingJavaRDD,Rating.class);
        //用户做离线的召回结果预测
        Dataset<Row> users = rating.select(alsModel.getUserCol());
        Dataset<Row> userRecs = alsModel.recommendForUserSubset(users,numItems).distinct().limit(userNumber);
        userRecs.foreachPartition(new ForeachPartitionFunction<Row>() {
            @Override
            public void call(Iterator<Row> iterator) throws Exception {
                List<String> data = new ArrayList<>();
                iterator.forEachRemaining(action->{
                    int userId = action.getInt(0);
                    System.out.println(userId);
                    List<GenericRowWithSchema> recommendationList = action.getList(1);
                    List<Integer> shopIdList = new ArrayList<Integer>();
                    recommendationList.forEach(row->{
                        Integer tagId = row.getInt(0);
                        shopIdList.add(tagId);
                    });
                    String recommendData = StringUtils.join(shopIdList,"#");
                    System.out.println(recommendData);

                    data.add(String.valueOf(userId) +","+ recommendData);
//                    storeRecommendTag(userId,recommendData);

                });
                //删除文件夹下所有文件
                File file = new File(PREDICT_DIR);
                deleteFile(file);
                csv.writeCsvFile(PREDICT_DIR,data);
            }
        });

问题:
本来在预测完结果后,希望直接调用redis,进行数据的存储操作,但是发现在链接Spark的代码文件中,使用@Resource进行注入时,回报没有序列化的错误
在这里插入图片描述
尝试各种方式,将调用的方法全部实现序列化的接口,但还是不行
在这里插入图片描述
在这里插入图片描述
最后,只能放弃在这个文件中进行依赖的注入。我先讲预测的结果存入csv文件中,在使用其他方法进行redis的写入。

推荐结果的写入和读取

在外部文件中调用相应的程序,实现模型的训练,然后读取预测好的结果,存入redis,并且预留调用的接口,方便进行调用
在这里插入图片描述

@Component
public class AlsRecommended {

    /**
     *数据存放的位置
     */
    private static String TAG_DATA_DIR = "src/main/resources/UserBrowsingData/user-tag/user_tag.csv";
    /**
     * 模型的存放位置
     */
    private static String TAG_MODEL_DIR = "src/main/resources/UserBrowsingData/user-tag/model";

    /**
     * 预测结果存放的位置
     */
    private static String TAG_PREDICE_DIR = "src/main/resources/UserBrowsingData/user-tag/predict/predict.csv";
    /**
     * redis的公共key  然后与userId进行拼接
     */
    private static String TAG_REDIS_KEY = "alsTagRecommended";
    /**
     *数据存放的位置
     */
    private static String CATEGORY_DATA_DIR = "src/main/resources/UserBrowsingData/user-category/user_category.csv";
    /**
     * 模型的存放位置
     */
    private static String CATEGORY_MODEL_DIR = "src/main/resources/UserBrowsingData/user-category/model";

    /**
     * 预测结果存放的位置
     */
    private static String CATEGORY_Predict_DIR = "src/main/resources/UserBrowsingData/user-category/predict/predict.csv";
    /**
     * redis的公共key  然后与userId进行拼接
     */
    private static String CATEGORY_REDIS_KEY = "alsCategoryRecommended";

    @Resource
    private AlsRecall alsRecall;
    @Resource
   private UserService userService;

    @Resource
    private RedisUtils redisUtils;

    /**
     * 初始化tag的推荐结果
     * @throws IOException
     */
    public void initTagRecommendedResult() throws IOException {

        alsRecall.trainAls(TAG_DATA_DIR,TAG_MODEL_DIR);
        int userNum = userService.getUserCount();
        alsRecall.predict(TAG_MODEL_DIR,TAG_DATA_DIR,TAG_PREDICE_DIR,userNum+1,10);

        List<String>  recommendedTags = csv.readCsvFile(TAG_PREDICE_DIR);
        if(recommendedTags != null)
        {
            for(String str :recommendedTags)
            {
                String list [] = str.split(",");
                storeRecommendTag(Integer.valueOf(list[0]),list[1]);
            }
        }

    }
    /**
     * 初始化category的推荐结果
     * @throws IOException
     */
    public void initCategoryRecommendedResult() throws IOException {

        alsRecall.trainAls(CATEGORY_DATA_DIR,CATEGORY_MODEL_DIR);
        int userNum = userService.getUserCount();
        alsRecall.predict(CATEGORY_MODEL_DIR,CATEGORY_DATA_DIR,CATEGORY_Predict_DIR,userNum+1,3);

        List<String>  recommendedCategorys = csv.readCsvFile(CATEGORY_Predict_DIR);
        if(recommendedCategorys != null)
        {
            for(String str :recommendedCategorys)
            {
                String list [] = str.split(",");
                storeRecommendCategory(Integer.valueOf(list[0]),list[1]);
            }
        }

    }

    /**
     * 存储标签的推荐的结果到redis
     * @param userId
     * @param recommendData  为tag的列表  用","分割  存入redis
     */
    public void storeRecommendTag(int userId,String recommendData)
    {
            redisUtils.set(TAG_REDIS_KEY+String.valueOf(userId),recommendData);
    }

    /**
     *从redis中读取标签的推荐结果
     * @param userId
     * @return
     */

    public String getRecommendTag(int userId)
    {
        return redisUtils.get(TAG_REDIS_KEY+String.valueOf(userId));
    }
    /**
     *从redis中读取分类的推荐结果
     * @param userId
     * @return
     */

    public String getRecommendCategory(int userId)
    {
        return redisUtils.get(CATEGORY_REDIS_KEY+String.valueOf(userId));
    }


    /**
     * 存储分类的推荐的结果到redis
     * @param userId
     * @param recommendData  为tag的列表  用","分割  存入redis
     */
    public void storeRecommendCategory(int userId,String recommendData)
    {
        redisUtils.set(CATEGORY_REDIS_KEY+String.valueOf(userId),recommendData);
    }

}

实现文章的推荐

读入redis提前推荐的标签,然后使用标签找到相应的文章进行读取


    /**
     * 基于ALS的 推荐标签 进而推荐文章
     * @param userId
     * @return
     */
    public List<Integer> recommendArticleByTagALS(int userId)
    {
        List<Integer> articleList = null;
        try {
            //推荐的标签
            String tagString = alsRecommended.getRecommendTag(userId);
            //将标签分割为数组
            String temp[] = tagString.split("#");

            List<Integer> tagIdList = new ArrayList<>();
            for (int i = 0; i < temp.length; i++) {
                tagIdList.add(Integer.valueOf(temp[i]));
            }
            articleList = new ArrayList<>();
            //根据标签 随机选出20篇文章
            if (tagIdList.size() != 0) {
                articleList = articleTagDao.selectArticleByTagList(userId, tagIdList, 20);
            }
        }
        catch (Exception e)
        {articleList = new ArrayList<>();

        }
        return articleList;
    }
  public List<Integer> recommendArticleByCategoryALS(int userId)
    {
        List<Integer> articleList = null;
        try {
            //推荐的标签
            String categoryString = alsRecommended.getRecommendCategory(userId);
            //将标签分割为数组
            String temp [] = categoryString.split("#");

            List<Integer> categoryIdList = new ArrayList<>();
            for(int i = 0 ; i < temp.length ; i++)
            {
                categoryIdList.add(Integer.valueOf(temp[i]));
            }
             articleList = new ArrayList<>();
            //根据标签 随机选出20篇文章
            if(categoryIdList.size() != 0)
            {
                articleList = articleCategoryDao.selectArticleByCategoryList(userId,categoryIdList,20);
            }
        }
        catch (Exception e)
        {
            articleList = new ArrayList<>();
        }



        return articleList;

    }

总结

这就是基于spark进行推荐的全部内容。这个推荐是离线推荐的,需要在每天晚上定时的抽取用户的浏览记录,然后使用Sprak提前对推荐的物品进行预计算,然后存入redis中,之后一天的推荐都是基于redis数据的离线推荐。

  • 6
    点赞
  • 28
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值