前言
昨晚本来想把这部分的博客内容,完成的,结果只写到了设计,时间就不早了,今天把具体的实现,还有实现过程中所遇到的所有的问题写在这里。
引入依赖
这次我用了Spark2.x的java api,并且了解到spark底层是scala实现了,然后上层的api有scala版本和java版本,这里我使用了它提供的java的api,并且java底层调用的函数都是scala实现的,非常的方便,可以与java进行无缝的操作。
spark MLlib依赖:
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-mllib_2.12</artifactId>
<version>2.4.4</version>
<exclusions>
<exclusion>
<groupId>com.google.guava</groupId>
<artifactId>guava</artifactId>
</exclusion>
</exclusions>
</dependency>
使用SparkSession链接Spark集群
//初始化spark初始环境
SparkSession spark = SparkSession.builder().master("local").appName("BlogStormApp").getOrCreate();
我们连接的是本地的master节点,在连接时出现了缺少依赖的问题
经过查阅资料之后,添加了两个依赖
<dependency>
<groupId>com.google.guava</groupId>
<artifactId>guava</artifactId>
<version>14.0.1</version>
</dependency>
<dependency>
<groupId>org.codehaus.janino</groupId>
<artifactId>janino</artifactId>
<version>3.0.8</version>
</dependency>
这连个依赖不是我们写的代码使用的,好像是Spark内部调用的方法使用的,缺少的话确实程序会运行失败的。
模型的训练
(1)使用spark的方法读取csv数据,然后序列化为javaRDD的格式
/** 内部序列化的静态类
* 数据模型 用户 分类 点击量
*/
public static class Rating implements Serializable{
private int userId;
private int itemId;
private int rating;
public static Rating parseRating(String str){
// str = str.replace("\"","");
String[] strArr = str.split(",");
int userId = Integer.parseInt(strArr[0]);
int itemId = Integer.parseInt(strArr[1]);
int rating = Integer.parseInt(strArr[2]);
return new Rating(userId,itemId,rating);
}
public Rating()
{
}
public Rating(int userId, int itemId, int rating) {
this.userId = userId;
this.itemId = itemId;
this.rating = rating;
}
public int getUserId() {
return userId;
}
public int getItemId() {
return itemId;
}
public int getRating() {
return rating;
}
}
JavaRDD<String> csvFile = spark.read().textFile(DATA_DIR).toJavaRDD();
JavaRDD<Rating> ratingJavaRDD = csvFile.map(new Function<String, Rating>() {
@Override
public Rating call(String s) {
return Rating.parseRating(s);
}
});
(2)将JavaRDD的格式转化为DataFrame的格式(DataFrame中的列可以是存储的文本,特征向量,真实标签和预测的标签等),然后切分训练集和测试集,方便验证模型。
然后调整参数,调用方法进行模型的训练
Dataset<Row> rating = spark.createDataFrame(ratingJavaRDD, Rating.class);
//将所有的rating数据分成82分
Dataset<Row>[] splits = rating.randomSplit(new double[]{0.8, 0.2});
Dataset<Row> trainingData = splits[0];
Dataset<Row> testingData = splits[1];
//5个特征,
//解决过拟合:1增大数据规模,减少rank,增大正则化系数
//欠拟合:增加rank,减少正则化系数
ALS als = new ALS().setMaxIter(10).setRank(5).setRegParam(0.01).setUserCol("userId")
.setItemCol("itemId").setRatingCol("rating");
ALSModel alsModel = als.fit(trainingData);
//模型评测 将测试集使用模型做一次转化的预测
Dataset<Row> predictions = alsModel.transform(testingData);
//rmse 均方根误差,预测值与真实值的偏差的平方除以观测次数,开个根号
RegressionEvaluator evaluator = new RegressionEvaluator().setMetricName("rmse")
.setLabelCol("rating").setPredictionCol("prediction");
double rmse = evaluator.evaluate(predictions);
System.out.println("rmse" + rmse);
(3)训练结果以及参数的调整(使用均方根误差来衡量)
均方根误差是预测值与真实值偏差的平方与观测次数n比值的平方根,在实际测量中,观测次数n总是有限的,真值只能用最可信赖(最佳)值来代替。精密度。这正是标准误差在工程测量中广泛被采用的原因。因此,标准差是用来衡量一组数自身的离散程度,而均方根误差是用来衡量观测值同真值之间的偏差,它们的研究对象和研究目的不同,但是计算过程类似 。
参数调整的原则:
解决过拟合:1增大数据规模,减少rank,增大正则化系数
欠拟合:增加rank,减少正则化系数
ALS als = new ALS().setMaxIter(10).setRank(5).setRegParam(0.01).setUserCol("userId")
.setItemCol("itemId").setRatingCol("rating");
当rank为5 正则化系数为0.01时的结果:
ALS als = new ALS().setMaxIter(10).setRank(7).setRegParam(0.01).setUserCol("userId")
.setItemCol("itemId").setRatingCol("rating");
当rank为7 正则化系数为0.01时的结果(发现提升rank值会增大误差,所以需要减少rank值):
ALS als = new ALS().setMaxIter(10).setRank(4).setRegParam(0.01).setUserCol("userId")
.setItemCol("itemId").setRatingCol("rating");
当rank为4 正则化系数为0.01时的结果(减小rank值之后,发现误差变小):
ALS als = new ALS().setMaxIter(10).setRank(3).setRegParam(0.02).setUserCol("userId")
.setItemCol("itemId").setRatingCol("rating");
当rank为3 正则化系数为0.02时的结果(提升正则化系数,误差变大了,所以降低一下正则化系数):
ALS als = new ALS().setMaxIter(10).setRank(4).setRegParam(0.005).setUserCol("userId")
.setItemCol("itemId").setRatingCol("rating");
当rank为4 正则化系数为0.005时的结果(减少正则化系数 误差经一步提高 ):
ALS als = new ALS().setMaxIter(10).setRank(3).setRegParam(0.01).setUserCol("userId")
.setItemCol("itemId").setRatingCol("rating");
当rank为3 正则化系数为0.01时的结果( rank为3时 误差比rank为4大,所以重新调整rank):
ALS als = new ALS().setMaxIter(10).setRank(4).setRegParam(0.009).setUserCol("userId")
.setItemCol("itemId").setRatingCol("rating");
当rank为4 正则化系数为0.09时的结果( 重新调整rank为4 正则化系数为0.09,效果不错):
ALS als = new ALS().setMaxIter(10).setRank(4).setRegParam(0.07).setUserCol("userId")
.setItemCol("itemId").setRatingCol("rating");
当rank为4 正则化系数为0.07时的结果( 正则化系数为0.07,误差变大):
ALS als = new ALS().setMaxIter(10).setRank(4).setRegParam(0.1).setUserCol("userId")
.setItemCol("itemId").setRatingCol("rating");
当rank为4 正则化系数为0.1时的结果( 正则化系数为0.1,误差变大):
所以还是rank为4,且正则化系数为0.09比较合适
(4)模型的存储
/**
* 删除文件夹下所有文件
*
* @param file
*/
private static void deleteFile(File file) {
//判断是否为文件,是,则删除
if (file.isFile()) {
file.delete();
} else {//不为文件,则为文件夹
//获取文件夹下所有文件相对路径
String[] childFilePath = file.list();
for (String path : childFilePath) {
File childFile = new File(file.getAbsoluteFile() + "/" + path);
//递归,对每个都进行判断
deleteFile(childFile);
}
file.delete();
}
}
//删除文件夹下所有文件
File file = new File(MODEL_DIR);
deleteFile(file);
//缓存模型
alsModel.save(MODEL_DIR);
读取模型进行预测
(1)模型的读取
//初始化spark运行环境
SparkSession spark = SparkSession.builder().master("local").appName("BlogStormApp").getOrCreate();
//将模型加载到内存之中
ALSModel alsModel = ALSModel.load(MODEL_DIR);
(2)进行预测
JavaRDD<String> csvFile = spark.read().textFile(DATA_DIR).toJavaRDD();
JavaRDD<Rating> ratingJavaRDD = csvFile.map(new Function<String, Rating>() {
@Override
public Rating call(String s) throws Exception {
return Rating.parseRating(s);
}
});
Dataset<Row> rating = spark.createDataFrame(ratingJavaRDD,Rating.class);
//用户做离线的召回结果预测
Dataset<Row> users = rating.select(alsModel.getUserCol());
Dataset<Row> userRecs = alsModel.recommendForUserSubset(users,numItems).distinct().limit(userNumber);
userRecs.foreachPartition(new ForeachPartitionFunction<Row>() {
@Override
public void call(Iterator<Row> iterator) throws Exception {
List<String> data = new ArrayList<>();
iterator.forEachRemaining(action->{
int userId = action.getInt(0);
System.out.println(userId);
List<GenericRowWithSchema> recommendationList = action.getList(1);
List<Integer> shopIdList = new ArrayList<Integer>();
recommendationList.forEach(row->{
Integer tagId = row.getInt(0);
shopIdList.add(tagId);
});
String recommendData = StringUtils.join(shopIdList,"#");
System.out.println(recommendData);
data.add(String.valueOf(userId) +","+ recommendData);
// storeRecommendTag(userId,recommendData);
});
//删除文件夹下所有文件
File file = new File(PREDICT_DIR);
deleteFile(file);
csv.writeCsvFile(PREDICT_DIR,data);
}
});
问题:
本来在预测完结果后,希望直接调用redis,进行数据的存储操作,但是发现在链接Spark的代码文件中,使用@Resource进行注入时,回报没有序列化的错误
尝试各种方式,将调用的方法全部实现序列化的接口,但还是不行
最后,只能放弃在这个文件中进行依赖的注入。我先讲预测的结果存入csv文件中,在使用其他方法进行redis的写入。
推荐结果的写入和读取
在外部文件中调用相应的程序,实现模型的训练,然后读取预测好的结果,存入redis,并且预留调用的接口,方便进行调用
@Component
public class AlsRecommended {
/**
*数据存放的位置
*/
private static String TAG_DATA_DIR = "src/main/resources/UserBrowsingData/user-tag/user_tag.csv";
/**
* 模型的存放位置
*/
private static String TAG_MODEL_DIR = "src/main/resources/UserBrowsingData/user-tag/model";
/**
* 预测结果存放的位置
*/
private static String TAG_PREDICE_DIR = "src/main/resources/UserBrowsingData/user-tag/predict/predict.csv";
/**
* redis的公共key 然后与userId进行拼接
*/
private static String TAG_REDIS_KEY = "alsTagRecommended";
/**
*数据存放的位置
*/
private static String CATEGORY_DATA_DIR = "src/main/resources/UserBrowsingData/user-category/user_category.csv";
/**
* 模型的存放位置
*/
private static String CATEGORY_MODEL_DIR = "src/main/resources/UserBrowsingData/user-category/model";
/**
* 预测结果存放的位置
*/
private static String CATEGORY_Predict_DIR = "src/main/resources/UserBrowsingData/user-category/predict/predict.csv";
/**
* redis的公共key 然后与userId进行拼接
*/
private static String CATEGORY_REDIS_KEY = "alsCategoryRecommended";
@Resource
private AlsRecall alsRecall;
@Resource
private UserService userService;
@Resource
private RedisUtils redisUtils;
/**
* 初始化tag的推荐结果
* @throws IOException
*/
public void initTagRecommendedResult() throws IOException {
alsRecall.trainAls(TAG_DATA_DIR,TAG_MODEL_DIR);
int userNum = userService.getUserCount();
alsRecall.predict(TAG_MODEL_DIR,TAG_DATA_DIR,TAG_PREDICE_DIR,userNum+1,10);
List<String> recommendedTags = csv.readCsvFile(TAG_PREDICE_DIR);
if(recommendedTags != null)
{
for(String str :recommendedTags)
{
String list [] = str.split(",");
storeRecommendTag(Integer.valueOf(list[0]),list[1]);
}
}
}
/**
* 初始化category的推荐结果
* @throws IOException
*/
public void initCategoryRecommendedResult() throws IOException {
alsRecall.trainAls(CATEGORY_DATA_DIR,CATEGORY_MODEL_DIR);
int userNum = userService.getUserCount();
alsRecall.predict(CATEGORY_MODEL_DIR,CATEGORY_DATA_DIR,CATEGORY_Predict_DIR,userNum+1,3);
List<String> recommendedCategorys = csv.readCsvFile(CATEGORY_Predict_DIR);
if(recommendedCategorys != null)
{
for(String str :recommendedCategorys)
{
String list [] = str.split(",");
storeRecommendCategory(Integer.valueOf(list[0]),list[1]);
}
}
}
/**
* 存储标签的推荐的结果到redis
* @param userId
* @param recommendData 为tag的列表 用","分割 存入redis
*/
public void storeRecommendTag(int userId,String recommendData)
{
redisUtils.set(TAG_REDIS_KEY+String.valueOf(userId),recommendData);
}
/**
*从redis中读取标签的推荐结果
* @param userId
* @return
*/
public String getRecommendTag(int userId)
{
return redisUtils.get(TAG_REDIS_KEY+String.valueOf(userId));
}
/**
*从redis中读取分类的推荐结果
* @param userId
* @return
*/
public String getRecommendCategory(int userId)
{
return redisUtils.get(CATEGORY_REDIS_KEY+String.valueOf(userId));
}
/**
* 存储分类的推荐的结果到redis
* @param userId
* @param recommendData 为tag的列表 用","分割 存入redis
*/
public void storeRecommendCategory(int userId,String recommendData)
{
redisUtils.set(CATEGORY_REDIS_KEY+String.valueOf(userId),recommendData);
}
}
实现文章的推荐
读入redis提前推荐的标签,然后使用标签找到相应的文章进行读取
/**
* 基于ALS的 推荐标签 进而推荐文章
* @param userId
* @return
*/
public List<Integer> recommendArticleByTagALS(int userId)
{
List<Integer> articleList = null;
try {
//推荐的标签
String tagString = alsRecommended.getRecommendTag(userId);
//将标签分割为数组
String temp[] = tagString.split("#");
List<Integer> tagIdList = new ArrayList<>();
for (int i = 0; i < temp.length; i++) {
tagIdList.add(Integer.valueOf(temp[i]));
}
articleList = new ArrayList<>();
//根据标签 随机选出20篇文章
if (tagIdList.size() != 0) {
articleList = articleTagDao.selectArticleByTagList(userId, tagIdList, 20);
}
}
catch (Exception e)
{articleList = new ArrayList<>();
}
return articleList;
}
public List<Integer> recommendArticleByCategoryALS(int userId)
{
List<Integer> articleList = null;
try {
//推荐的标签
String categoryString = alsRecommended.getRecommendCategory(userId);
//将标签分割为数组
String temp [] = categoryString.split("#");
List<Integer> categoryIdList = new ArrayList<>();
for(int i = 0 ; i < temp.length ; i++)
{
categoryIdList.add(Integer.valueOf(temp[i]));
}
articleList = new ArrayList<>();
//根据标签 随机选出20篇文章
if(categoryIdList.size() != 0)
{
articleList = articleCategoryDao.selectArticleByCategoryList(userId,categoryIdList,20);
}
}
catch (Exception e)
{
articleList = new ArrayList<>();
}
return articleList;
}
总结
这就是基于spark进行推荐的全部内容。这个推荐是离线推荐的,需要在每天晚上定时的抽取用户的浏览记录,然后使用Sprak提前对推荐的物品进行预计算,然后存入redis中,之后一天的推荐都是基于redis数据的离线推荐。