召回(ALS)接入
离线召回(ALS)的结果之前已经保存在 MySQL 中;
在线召回时直接按 userId 查出即可(查不到时回退到默认用户 9999999 的召回结果);
package tech.lixinlei.dianping.recommand;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import tech.lixinlei.dianping.dal.RecommendModelMapper;
import tech.lixinlei.dianping.model.RecommendModel;
@Service
public class RecommendService {

    /**
     * Primary key of the fallback row in the recommend table, used when the requested user
     * has no recall result of their own.
     */
    private static final int DEFAULT_USER_ID = 9999999;

    @Autowired
    private RecommendModelMapper recommendModelMapper;

    /**
     * Recalls candidate shop ids for a user from the offline recall results stored in MySQL.
     *
     * <p>If no row exists for {@code userId}, the row of {@link #DEFAULT_USER_ID} is used
     * instead.
     *
     * @param userId id of the user to recall shops for
     * @return shop ids in the order they are stored in the recommend column; an empty list
     *     (never {@code null}) when neither the user's row nor the fallback row exists, or
     *     when the stored list is blank
     * @throws NumberFormatException if a non-blank entry of the stored list is not an integer
     */
    public List<Integer> recall(Integer userId) {
        RecommendModel recommendModel = recommendModelMapper.selectByPrimaryKey(userId);
        if (recommendModel == null) {
            recommendModel = recommendModelMapper.selectByPrimaryKey(DEFAULT_USER_ID);
        }
        if (recommendModel == null) {
            // Even the fallback row is missing; return nothing rather than throwing an NPE.
            return new ArrayList<>();
        }
        return parseShopIds(recommendModel.getRecommend());
    }

    /**
     * Parses a comma-separated id list, ignoring a null input, blank entries, and whitespace
     * around each entry.
     */
    private static List<Integer> parseShopIds(String csv) {
        List<Integer> shopIdList = new ArrayList<>();
        if (csv == null) {
            return shopIdList;
        }
        for (String token : csv.split(",")) {
            String trimmed = token.trim();
            // "".split(",") yields [""], and stray commas yield empty tokens; skip them.
            if (!trimmed.isEmpty()) {
                shopIdList.add(Integer.valueOf(trimmed));
            }
        }
        return shopIdList;
    }
}
排序(LR)接入
package tech.lixinlei.dianping.recommand;
import org.apache.spark.ml.classification.GBTClassificationModel;
import org.apache.spark.ml.classification.LogisticRegressionModel;
import org.apache.spark.ml.linalg.Vector;
import org.apache.spark.ml.linalg.Vectors;
import org.apache.spark.sql.SparkSession;
import org.springframework.stereotype.Service;
import javax.annotation.PostConstruct;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.stream.Collectors;
@Service
public class RecommendSortService {

    /**
     * Location of the trained logistic-regression model.
     * NOTE(review): absolute path on a developer machine — should come from configuration.
     */
    private static final String LR_MODEL_PATH =
            "file:///home/lixinlei/project/gitee/dianping/src/main/resources/lrmode";

    private SparkSession spark;
    private LogisticRegressionModel lrModel;

    /**
     * Starts a local Spark session and loads the LR model once, when the bean is initialised.
     */
    @PostConstruct
    public void init() {
        // The SparkSession is created before the model is loaded.
        spark = SparkSession.builder().master("local").appName("DianpingApp").getOrCreate();
        lrModel = LogisticRegressionModel.load(LR_MODEL_PATH);
    }

    /**
     * Ranks recalled shops by the LR model's positive-class probability, highest first.
     *
     * @param shopIdList recalled shop ids; {@code null} or empty yields an empty list
     * @param userId id of the user the ranking is for (passed to feature construction)
     * @return the same shop ids reordered by descending score; shops with equal scores keep
     *     their recall order
     */
    public List<Integer> sort(List<Integer> shopIdList, Integer userId) {
        if (shopIdList == null || shopIdList.isEmpty()) {
            return new ArrayList<>();
        }
        List<ShopSortModel> list = new ArrayList<>(shopIdList.size());
        for (Integer shopId : shopIdList) {
            ShopSortModel shopSortModel = new ShopSortModel();
            shopSortModel.setShopId(shopId);
            shopSortModel.setScore(score(buildFeatures(shopId, userId)));
            list.add(shopSortModel);
        }
        // List.sort is stable, so ties keep the order produced by recall.
        list.sort(Comparator.comparingDouble(ShopSortModel::getScore).reversed());
        return list.stream().map(ShopSortModel::getShopId).collect(Collectors.toList());
    }

    /**
     * Builds the 11-dimensional feature vector the LR model expects.
     *
     * <p>NOTE(review): currently a hard-coded placeholder, identical for every shop and user,
     * so every shop receives the same score. Real features (e.g. gender, age, rating, price)
     * should be loaded from the database or cache and encoded here.
     */
    private Vector buildFeatures(Integer shopId, Integer userId) {
        return Vectors.dense(1, 0, 0, 0, 0, 1, 0.6, 0, 0, 1, 0);
    }

    /** Returns the model's probability for the positive class (index 1 of the probability vector). */
    private double score(Vector features) {
        double[] probabilities = lrModel.predictProbability(features).toArray();
        return probabilities[1];
    }
}
修改原来的 recommend 方法的实现:
先召回,再排序;
package tech.lixinlei.dianping.service.impl;
@Service
public class ShopServiceImpl implements ShopService {

    /** NOTE(review): fixed demo user until the caller's identity is passed into recommend. */
    private static final Integer DEMO_USER_ID = 148;

    /** Placeholder cover image applied to every recommended shop. */
    private static final String PLACEHOLDER_ICON_URL = "/static/image/shopcover/xchg.jpg";

    /** Placeholder distance applied to every recommended shop; no real distance is computed. */
    private static final int PLACEHOLDER_DISTANCE = 100;

    @Autowired
    RecommendService recommendService;

    @Autowired
    RecommendSortService recommendSortService;

    /**
     * Recommends shops by first recalling candidates and then ranking them.
     *
     * @param longitude caller's longitude; currently unused
     * @param latitude caller's latitude; currently unused
     * @return ranked shops with placeholder icon and distance; recalled ids for which
     *     {@code get} returns {@code null} are skipped instead of causing an NPE
     */
    @Override
    public List<ShopModel> recommend(BigDecimal longitude, BigDecimal latitude) {
        List<Integer> recalled = recommendService.recall(DEMO_USER_ID);
        List<Integer> ranked = recommendSortService.sort(recalled, DEMO_USER_ID);
        return ranked.stream()
                .map(id -> get(id))
                .filter(shopModel -> shopModel != null)
                .map(this::applyPlaceholders)
                .collect(Collectors.toList());
    }

    /** Fills the display fields that are not yet computed with fixed placeholder values. */
    private ShopModel applyPlaceholders(ShopModel shopModel) {
        shopModel.setIconUrl(PLACEHOLDER_ICON_URL);
        shopModel.setDistance(PLACEHOLDER_DISTANCE);
        return shopModel;
    }
}