Java 接入 ALS 召回与 LR 排序,为用户推荐商户

召回(ALS)接入

之前离线召回(ALS)得到的结果已经保存在 MySQL 中,这里按 userId 直接查询取出即可;如果该用户没有自己的召回结果,则使用默认用户(id 为 9999999)的推荐列表兜底。

package tech.lixinlei.dianping.recommand;

import java.io.Serializable;

import java.util.ArrayList;

import java.util.List;

import org.springframework.beans.factory.annotation.Autowired;

import org.springframework.stereotype.Service;

import tech.lixinlei.dianping.dal.RecommendModelMapper;

import tech.lixinlei.dianping.model.RecommendModel;

@Service
public class RecommendService {

    /**
     * Primary key of the fallback row in the recommend table, used when a user has no
     * recall result of their own.
     */
    private static final int DEFAULT_RECALL_USER_ID = 9999999;

    @Autowired
    private RecommendModelMapper recommendModelMapper;

    /**
     * Recalls candidate shop ids for a user from the offline (ALS) recall results stored in MySQL.
     *
     * <p>If the user has no row of their own, the row keyed by {@link #DEFAULT_RECALL_USER_ID}
     * is used instead. If neither row exists, or the stored list is {@code null} or empty, an
     * empty list is returned instead of throwing.
     *
     * @param userId id of the user to recall shops for
     * @return shop ids in the order they are stored; never {@code null}
     * @throws NumberFormatException if a stored entry is not a valid integer
     */
    public List<Integer> recall(Integer userId) {
        RecommendModel recommendModel = recommendModelMapper.selectByPrimaryKey(userId);
        if (recommendModel == null) {
            recommendModel = recommendModelMapper.selectByPrimaryKey(DEFAULT_RECALL_USER_ID);
        }
        // Guard: the fallback row itself may be missing, or the column may be null.
        if (recommendModel == null || recommendModel.getRecommend() == null) {
            return new ArrayList<>();
        }
        return parseShopIds(recommendModel.getRecommend());
    }

    /**
     * Parses a comma-separated list of shop ids, ignoring surrounding whitespace and empty
     * entries (so "" or "1,,2" do not throw).
     */
    private static List<Integer> parseShopIds(String csv) {
        List<Integer> shopIdList = new ArrayList<>();
        for (String token : csv.split(",")) {
            String trimmed = token.trim();
            if (!trimmed.isEmpty()) {
                shopIdList.add(Integer.valueOf(trimmed));
            }
        }
        return shopIdList;
    }
}

排序(LR)接入

package tech.lixinlei.dianping.recommand;

import org.apache.spark.ml.classification.GBTClassificationModel;
import org.apache.spark.ml.classification.LogisticRegressionModel;
import org.apache.spark.ml.linalg.Vector;
import org.apache.spark.ml.linalg.Vectors;
import org.apache.spark.sql.SparkSession;
import org.springframework.stereotype.Service;

import javax.annotation.PostConstruct;
import javax.annotation.PreDestroy;

import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.stream.Collectors;

@Service
public class RecommendSortService {

    /** Default location of the trained LR model directory. */
    private static final String DEFAULT_LR_MODEL_PATH =
            "file:///home/lixinlei/project/gitee/dianping/src/main/resources/lrmode";

    /** JVM system property that, when set, overrides {@link #DEFAULT_LR_MODEL_PATH}. */
    private static final String LR_MODEL_PATH_PROPERTY = "dianping.lrmodel.path";

    private SparkSession spark;

    private LogisticRegressionModel lrModel;

    /**
     * Starts a local Spark session and loads the LR model once, when the bean is initialised.
     * The model path is {@code -Ddianping.lrmodel.path=...} if given, else the default path.
     */
    @PostConstruct
    public void init() {
        // The session is created before load(); presumably the model reader needs an active
        // SparkSession (NOTE(review): confirm against the Spark version in use).
        spark = SparkSession.builder().master("local").appName("DianpingApp").getOrCreate();
        lrModel = LogisticRegressionModel.load(
                System.getProperty(LR_MODEL_PATH_PROPERTY, DEFAULT_LR_MODEL_PATH));
    }

    /** Stops the Spark session started in {@link #init()} when the bean is destroyed. */
    @PreDestroy
    public void destroy() {
        if (spark != null) {
            spark.stop();
        }
    }

    /**
     * Ranks recalled shops by the LR model's positive-class probability, highest first.
     *
     * @param shopIdList recalled shop ids; {@code null} or empty yields an empty list
     * @param userId     user the ranking is for (passed to the feature builder; currently unused)
     * @return shop ids ordered by descending score; ties keep their input order (stable sort)
     */
    public List<Integer> sort(List<Integer> shopIdList, Integer userId) {
        if (shopIdList == null || shopIdList.isEmpty()) {
            return new ArrayList<>();
        }

        List<ShopSortModel> scored = new ArrayList<>(shopIdList.size());
        for (Integer shopId : shopIdList) {
            double[] probabilities = lrModel.predictProbability(buildFeatures(shopId, userId)).toArray();
            ShopSortModel shopSortModel = new ShopSortModel();
            shopSortModel.setShopId(shopId);
            // Index 1 is the probability of the positive class.
            shopSortModel.setScore(probabilities[1]);
            scored.add(shopSortModel);
        }

        scored.sort(Comparator.comparingDouble(ShopSortModel::getScore).reversed());
        return scored.stream().map(ShopSortModel::getShopId).collect(Collectors.toList());
    }

    /**
     * Builds the 11-dimensional feature vector the LR model expects for one (user, shop) pair.
     *
     * <p>This currently returns a fixed placeholder vector, the same for every pair. As a result
     * every shop gets the same score and the stable sort keeps the recall order. TODO: build the
     * features from the user's gender/age and the shop's rating/price, loaded from DB or cache.
     */
    private Vector buildFeatures(Integer shopId, Integer userId) {
        return Vectors.dense(1, 0, 0, 0, 0, 1, 0.6, 0, 0, 1, 0);
    }
}

修改原来的 recommend 方法的实现:先召回,再排序。

package tech.lixinlei.dianping.service.impl;

@Service
public class ShopServiceImpl implements ShopService {

    /**
     * User whose recall and ranking is used for every request.
     * TODO: use the logged-in user's id once the API supplies it.
     */
    private static final int DEMO_USER_ID = 148;

    /** Placeholder cover image applied to every recommended shop. */
    private static final String DEMO_ICON_URL = "/static/image/shopcover/xchg.jpg";

    /** Placeholder distance applied to every recommended shop; not derived from coordinates. */
    private static final int DEMO_DISTANCE = 100;

    @Autowired
    RecommendService recommendService;

    @Autowired
    RecommendSortService recommendSortService;

    /**
     * Recommends shops by first recalling candidates and then ranking them with the LR model.
     *
     * <p>Shop ids that {@code get} cannot resolve ({@code null}) are skipped instead of
     * causing a NullPointerException.
     *
     * @param longitude caller's longitude; currently unused by this implementation
     * @param latitude  caller's latitude; currently unused by this implementation
     * @return recommended shops in ranked order; never {@code null}
     */
    @Override
    public List<ShopModel> recommend(BigDecimal longitude, BigDecimal latitude) {
        List<Integer> shopIdList = recommendService.recall(DEMO_USER_ID);
        shopIdList = recommendSortService.sort(shopIdList, DEMO_USER_ID);

        return shopIdList.stream()
                .map(id -> get(id))
                .filter(shopModel -> shopModel != null)
                .map(shopModel -> {
                    shopModel.setIconUrl(DEMO_ICON_URL);
                    shopModel.setDistance(DEMO_DISTANCE);
                    return shopModel;
                })
                .collect(Collectors.toList());
    }
}

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值