推荐系统(工程方向)-策略平台

一、背景

假设某个App的首页推荐由2个策略产品经理、6个算法RD和2个工程RD协同负责。如果没有一套可视化的策略执行流程,沟通会非常低效,对接也容易出错,最终难以达到预期效果。

其次推荐系统一般可抽象出几个核心流程:

  1. 获取用户画像+用户过滤列表
  2. 召回
  3. 排序
  4. 重排
  5. 人工运营

这些核心流程可以抽象成公共组件,通过配置组装到执行流程中,从而减少大量重复开发工作。

二、方案设计

1、设计思路

  • 使用DAG(有向无环图)构建用户请求的处理模块(模块=DAG节点=策略)
  • 使用一个数据流对象(DataFlow),串联DAG所有节点
  • 使用锁(DataFlow中的读写锁ReadWriteLock)控制存在竞态条件的并发读写
  • 通过UI修改DAG配置,近实时生效
  • 支持同一个DAG中多个同名模块调用
  • 支持子DAG调用

2、设计图

简易架构图

DAG编辑页面效果图

3、核心代码

Node类

/**
 * A vertex of the strategy DAG.
 *
 * <p>Each node carries its label, its outgoing edges, two in-degree counters and the
 * model instance attached to it at execution time.
 */
@Getter
@Setter
@ToString(exclude = {"nextEdges"})
public class Node {
    // --- graph structure ---
    /** Label identifying this node. */
    private String key;
    /** Outgoing edges, i.e. the nodes that depend on this one. */
    private List<Edge> nextEdges;
    /** Number of incoming edges of this node. */
    private int inDegree;
    /** Second in-degree counter, kept separately for the pre-use validity check. */
    private int inDegreeCheck;

    /**
     * Model attached to this node at execution time from the already-instantiated
     * classes held in memory. {@code null} means the node is skipped.
     */
    private GraphModel graphModel;

    /**
     * Creates a node with the given label and no outgoing edges.
     *
     * @param key label identifying the node
     */
    public Node(String key) {
        this.key = key;
        this.nextEdges = new LinkedList<>();
    }

    /** Increments the in-degree and returns the new value. */
    public int addInDegree() {
        return ++inDegree;
    }

    /** Decrements the in-degree and returns the new value. */
    public int decreaseInDegree() {
        return --inDegree;
    }

    /** Increments the validation in-degree counter and returns the new value. */
    public int addInDegreeCheck() {
        return ++inDegreeCheck;
    }

    /** Decrements the validation in-degree counter and returns the new value. */
    public int decreaseInDegreeCheck() {
        return --inDegreeCheck;
    }

}

Edge类

/**
 * A directed edge of the strategy DAG. Only the target node is stored; the source node
 * is the {@code Node} whose edge list contains this edge.
 */
@Getter
@Setter
@ToString
public class Edge {
    /** Node this edge points to. */
    private Node endNode;

    /**
     * Creates an edge pointing at the given node.
     *
     * @param target node the edge ends at
     */
    public Edge(Node target) {
        endNode = target;
    }
}

DirectedGraph类

/**
 * Directed acyclic graph (DAG) built from a textual edge list.
 *
 * <p>The script has one edge per line in the form {@code start,end}. All whitespace other
 * than line breaks is ignored. Nodes are created on first mention and kept in order of
 * first appearance.
 */
@Setter
@Getter
public class DirectedGraph {

    /** Nodes keyed by label, in order of first appearance in the script. */
    private Map<String, Node> directedGraphMap;

    /**
     * Parses the given script and builds the graph.
     *
     * @param graphContent edge list, one {@code start,end} pair per line
     * @throws RuntimeException if a line does not contain exactly two comma-separated labels
     */
    public DirectedGraph(String graphContent) {
        directedGraphMap = new LinkedHashMap<String, Node>();
        buildGraph(graphContent);
    }

    /**
     * Parses the script line by line, creating nodes on demand and wiring one edge per line.
     *
     * @param graphContent edge list, one {@code start,end} pair per line
     */
    private void buildGraph(String graphContent) {
        // Strip all whitespace except line breaks: newlines are parked as back-ticks while
        // "\\s*" removes every whitespace character, then restored afterwards.
        graphContent = graphContent.replaceAll("\n", "`").replaceAll("\\s*", "").replaceAll("`", "\n");
        String[] lines = graphContent.split("\n");
        for (int i = 0; i < lines.length; i++) {
            String[] nodesInfo = lines[i].split(",");
            if (nodesInfo.length != 2) {
                throw new RuntimeException((i + 1) + "行包含" + nodesInfo.length + "节点,每行只能2个节点!");
            }
            // Start node first, then end node, so insertion order matches first appearance.
            Node startNode = directedGraphMap.computeIfAbsent(nodesInfo[0], Node::new);
            Node endNode = directedGraphMap.computeIfAbsent(nodesInfo[1], Node::new);

            // Every line is one edge: attach it to the start node and bump both in-degree
            // counters of the end node.
            startNode.getNextEdges().add(new Edge(endNode));
            endNode.addInDegree();
            endNode.addInDegreeCheck();
        }
    }

    /**
     * Checks that the graph is acyclic by running a topological sort (repeatedly removing
     * nodes whose remaining in-degree is zero) and verifying that every node gets visited.
     *
     * <p>The in-degrees are recomputed from the edges into a scratch map, so the nodes'
     * own counters are not modified and the method may be called any number of times with
     * the same result. Previously the nodes' {@code inDegreeCheck} counters were
     * decremented in place, which could make a second call on a cyclic graph return
     * {@code true}.
     *
     * @return {@code true} if every node can be topologically ordered, {@code false} if the
     *         graph contains a cycle
     */
    public boolean validate() {
        Collection<Node> nodes = directedGraphMap.values();

        // Scratch in-degree table keyed by node identity.
        Map<Node, Integer> remaining = new java.util.IdentityHashMap<>();
        for (Node node : nodes) {
            remaining.put(node, 0);
        }
        for (Node node : nodes) {
            for (Edge edge : node.getNextEdges()) {
                remaining.merge(edge.getEndNode(), 1, Integer::sum);
            }
        }

        // Seed the queue with every node that has no incoming edge.
        Queue<Node> ready = new LinkedList<>();
        for (Node node : nodes) {
            if (remaining.get(node) == 0) {
                ready.offer(node);
            }
        }

        int visited = 0;
        while (!ready.isEmpty()) {
            Node current = ready.poll();
            visited++;
            // Releasing this node lowers the in-degree of each dependent; a dependent whose
            // count reaches zero becomes ready.
            for (Edge edge : current.getNextEdges()) {
                if (remaining.merge(edge.getEndNode(), -1, Integer::sum) == 0) {
                    ready.offer(edge.getEndNode());
                }
            }
        }
        // Nodes on a cycle never reach in-degree zero and therefore are never visited.
        return visited == nodes.size();
    }
}

GraphModelExecutor(执行引擎的execute方法,按层广度优先遍历DAG,同一层有多个节点时并行执行)

/**
 * Executes the DAG configured for the request's project.
 *
 * <p>Steps, in order: run the abtest interceptor, look up the project by
 * {@code dataFlow.getProjectKey()}, run the pageSize/count/itemType interceptors, build and
 * validate a {@code DirectedGraph} from the project's DAG script, attach a model to each
 * node, then process the nodes batch by batch starting from those with in-degree zero.
 * A batch with more than one node is dispatched to one of three parallel helpers chosen by
 * {@code Config.getConfig().getParallelWay()}; a single-node batch goes to
 * {@code dealOneNode}.
 *
 * @param dataFlow per-request data object passed to every interceptor and node
 * @return the graph state; returned early (after only the DAG script was recorded) when
 *         the graph fails validation or a DAG module refers to the current project itself
 */
public GraphState execute(final DataFlow dataFlow) {
        // records the state of this DAG run
        final GraphState graphState = new GraphState();
        Stopwatch stopwatch = Stopwatch.createStarted();
        // abtest interceptor (runs before the project lookup, so no project is passed)
        aBtestInterceptor.interceptor(null, dataFlow);
        // project wraps the business config (business code + scene code + abtest locate a project)
        final Project project = projectManager.getProjectMap().get(dataFlow.getProjectKey());
        Preconditions.checkNotNull(project, dataFlow.getProjectKey() + " has no project!");
        // pageSize interceptor
        pageSizeInterceptor.interceptor(project, dataFlow);
        // count interceptor; per the original comment it injects the redundancy (extra count) design
        countInterceptor.interceptor(project, dataFlow);
        // itemType interceptor
        itemTypeInterceptor.interceptor(project, dataFlow);
        final Queue<List<Node>> queue = new LinkedBlockingQueue<>();// work queue for the traversal; each element is one batch of nodes
        DirectedGraph directedGraph = new DirectedGraph(project.getDagScript());
        graphState.setDagScript(project.getDagScript());// record the graph script first
        if (!directedGraph.validate()) {
            log.error("【{}】requestId={}`userId={},deviceId={},dagScript={},directedGraph is not right!", dataFlow.getProjectKey(), dataFlow.getRequestId(), dataFlow.getUserId(), dataFlow.getDeviceId(), project.getDagScript());
            return graphState;
        }
        // resolve the model for every node
        Collection<Node> nodes = directedGraph.getDirectedGraphMap().values();
        for (Node node : nodes) {
            // A node is either (1) a regular module or (2) a DAG (sub-graph) module.
            // Note: if the graph script was published but the code was not, no model is found;
            // the node is then left without a model and an alert is sent.
            GraphModel graphModel;
            if (GraphModelUtil.isDagModel(node.getKey())) {
                graphModel = graphModelManager.getModel(GraphModelManager.getSimpleName(DAGModel.class.getSimpleName()));
                // strip '$' characters from the node key and compare with the current project's
                // bzCode_sceneCode key: a graph must not include itself as a sub-graph
                String bzCodeSceneCode = node.getKey().replaceAll("\\$", "");
                if (bzCodeSceneCode.equals(dataFlow.getBzCodeSceneCodeKey())) {
                    log.error("构图错误,不能把自己作为一个图模块={}", node.getKey());
                    return graphState;
                }
            } else if (BaseCommonUtil.isMultiInstanceModel(node.getKey())) {
                // multi-instance key: the part before the splitter names the model,
                // e.g. "ModelA:a" resolves to "ModelA" (assumes ':' is the splitter, per the original comment)
                String[] modelKeyArr = node.getKey().split(BaseConsts.modelInstanceSpliter, 2);
                graphModel = graphModelManager.getModel(modelKeyArr[0]);
            } else {
                graphModel = graphModelManager.getModel(node.getKey());
            }
            if (graphModel != null) {
                // per the original comment this checks the module's circuit-breaker state;
                // the node keeps a null model when the check returns false
                if (checkModelStat(project, node.getKey())) {
                    node.setGraphModel(graphModel);
                }
            } else {
                // label is "pre-release" when the env string ends with Consts.env.pre, otherwise "production"
                String env = projectManager.getEnv().endsWith(Consts.env.pre) ? "预发" : "线上";
                log.error("className=" + node.getKey() + " has no model!");
                CommonUtil.sendWarningMsg(env + "环境," + node.getKey() + "模块异常", "推荐平台", HttpWechat.INFO_LEVEL_ERROR
                        , dataFlow.getProjectKey() + "的" + node.getKey() + "模块未找到实现类,可能代码未提交或模块实例化异常!", "");
            }
        }
        // collect every node with in-degree zero as the first batch (there may be several)
        List<Node> firstNodeList = new ArrayList<>();
        for (Node node : nodes) {
            if (node.getInDegree() == 0) {
                // entry task: goes into the first batch
                firstNodeList.add(node);
            }
        }
        queue.offer(firstNodeList);
        // main loop: process batches until the queue is drained
        while (!queue.isEmpty()) {
            final List<Node> visitNodeList = queue.poll();
            if (visitNodeList == null || visitNodeList.size() == 0) {
                log.error("【{}】requestId={}`userId={},deviceId={} 该层没有任何节点!", dataFlow.getProjectKey(), dataFlow.getRequestId(), dataFlow.getUserId(), dataFlow.getDeviceId());
                continue;
            }
            // enqueue the next batch from the calling thread, before this batch is run
            setNextNodeList(visitNodeList, queue);
            final boolean isParallel = visitNodeList.size() > 1 ? true : false;
            // the flag is passed down so logs can tell single-node from parallel runs (per the original comment)
            if (isParallel) {
                // parallel strategy is chosen by configuration: 1, 2, or anything else
                if (Config.getConfig().getParallelWay() == 1) {
                    countDownLatch(project, visitNodeList, dataFlow, isParallel, graphState);
                } else if (Config.getConfig().getParallelWay() == 2) {
                    completableFuture(project, visitNodeList, dataFlow, isParallel, graphState);
                } else {
                    completeService(project, visitNodeList, dataFlow, isParallel, graphState);
                }
            } else {
                Node node = visitNodeList.get(0);
                long nodeTimeout = getTimeout(project, node.getKey());
                dealOneNode(project, visitNodeList.get(0), dataFlow, isParallel, graphState, nodeTimeout, true);
            }
        }
        // after the project has finished running, set item_type
        itemTypeAfterInterceptor.interceptor(project, dataFlow);
        recordFusingMsgByGraphState(project, graphState);
        log.info("【{}】requestId={}`userId={},deviceId={},executed all_nodes spends={}", project.getProjectKey(), dataFlow.getRequestId(), dataFlow.getUserId(), dataFlow.getDeviceId(), stopwatch.elapsed(TimeUnit.MILLISECONDS));
        return graphState;
    }

DataFlow类

/**
 * Data-flow object: the single per-request carrier that is handed to every DAG node.
 *
 * <p>It holds the request parameters, the user profiles, the filter list, the
 * recommendation list being built, and extension maps used to pass data between modules.
 * Only the recall-profile accessors are {@code synchronized} and only
 * {@code disableModelSet} is a concurrent collection; other fields are plain.
 */
@Setter
@Getter
@ToString(exclude = {"rwLock", "version", "filterList", "rcmdList", "userProf"})
@Slf4j
public class DataFlow {
    /** Key under which the recall profile is stored in {@link #extraMap}. */
    private static final String RECALL_PROFILE_KEY = "RECALL_PROFILE";

    /*** internal state ***/
    private final List<Long> filterList = new FilterCollection<>();       // filter list
    private ReadWriteLock rwLock = new ReentrantReadWriteLock();// lock for use where race conditions exist
    private AtomicLong version = new AtomicLong();// current version number
    /*** required request fields ***/
    private long userId;
    private String deviceId;
    private int vercode;
    private String deviceType;
    private int appID;
    private long subAppId;
    private String bzCode;// business entity
    private String sceneCode;// scene / page the user is on
    private int pageNo;// current page number
    private int pageSize;// items per page
    private int count;// page size plus more_count (extra recall volume for better results); original comment says "pageNumber" — TODO confirm
    private List<ItemScore> rcmdList = new ArrayList<>();    // recommendation data
    private UserProfile userProf = new UserProfile();    // user profile
    private UserProfile njUserProf;    // user profile on the broadcaster side
    private List<ItemProfile> itemExpo;           // currently exposed items, if any
    private String abtest;// ab flag
    private boolean isDag;// marks a call coming from a DAG (sub-graph) module; only the base group (A) may be called
    private Map<String, Object> extraMap = new HashMap<>();      // extension data passed between modules
    private JSONObject extraJson = new JSONObject();// receives the "extra" json parameter of the protocol
    private JSONObject resultExtraJson = new JSONObject();// request-level extension result, a sibling of List<RcmdItem> (e.g. live dataversiontime)
    private boolean isOriginal;// whether to return only the native data type (e.g. playlists only, without topics etc.)
    private int itemType;
    private String requestId;// unique request id, used for log tracing
    private String hitExps;// all experiments hit on the abtest platform, e.g. exp1|exp2|exp3
    private Map<String, String> expParamMap = new HashMap<>();// read-only; config fetched from the experiment platform, key=modelName
    // read and written concurrently; the experiment platform adds modules here that must not run
    private Set<String> disableModelSet = Collections.newSetFromMap(new ConcurrentHashMap<String, Boolean>());
    private String baseInfo = null;

    /**
     * Builds a log line with the base info, the next version number and the current sizes
     * of the filter and recommendation lists. Every call increments the version counter.
     *
     * @return the formatted log line
     */
    public String version() {
        return toBaseInfo()
                + "`version=" + this.getVersion().incrementAndGet()
                + "`filterList=" + filterList.size()
                + "`rcmdList=" + rcmdList.size();
    }

    /**
     * Key matching the project configuration, used to find the project for a request:
     * {@code bzCode_sceneCode}, with {@code _abtest} appended when abtest is not blank.
     *
     * @return the project key
     */
    public String getProjectKey() {
        String projectKey = bzCode + "_" + sceneCode;
        if (abtest != null && abtest.trim().length() > 0) {
            projectKey = projectKey + "_" + abtest;
        }
        return projectKey;
    }

    /**
     * Returns the base request data formatted for logging. The string is built on the first
     * call and cached, so later changes to the underlying fields are not reflected.
     *
     * @return the cached base-info string
     */
    public String toBaseInfo() {
        if (baseInfo == null) {
            baseInfo = "userid=" + userId +
                    "`deviceid=" + deviceId +
                    "`projectKey=" + getProjectKey() +
                    "`pageNo=" + pageNo +
                    "`pageSize=" + pageSize +
                    "`hitExps=" + hitExps +
                    "`requestId=" + requestId;
        }

        return baseInfo;
    }

    /**
     * Stores a value in the extension map, logging a warning when the key is overwritten.
     *
     * @param key extension key
     * @param val value to store
     */
    public void addExtraMap(String key, Object val) {
        if (extraMap.containsKey(key)) {
            log.warn("key already exist`key={}`oldVal={}`newVal={}", key, extraMap.get(key), val);
        }
        extraMap.put(key, val);
    }

    /**
     * Reads a value from the extension map.
     *
     * @param key extension key
     * @param <T> type the caller expects; the cast is unchecked
     * @return the stored value, or {@code null} when absent
     */
    @SuppressWarnings("unchecked")
    public <T> T getFromExtraMap(String key) {
        return (T) extraMap.get(key);
    }

    /**
     * Reads a value from the extension map, falling back to a default.
     *
     * @param key          extension key
     * @param defaultValue value returned when nothing (or {@code null}) is stored
     * @param <T>          type the caller expects; the cast is unchecked
     * @return the stored value, or {@code defaultValue} when absent
     */
    @SuppressWarnings("unchecked")
    public <T> T getFromExtraMap(String key, T defaultValue) {
        Object value = extraMap.get(key);
        return value == null ? defaultValue : (T) value;
    }

    /**
     * {@code bzCode_sceneCode} identifies a project; abtest only marks that the project is
     * under an ab test and is therefore not part of this key.
     *
     * @return {@code bzCode_sceneCode}
     */
    public String getBzCodeSceneCodeKey() {
        return bzCode + "_" + sceneCode;
    }

    /**
     * Collects the item ids of the current recommendation list, in list order.
     *
     * @return a new list of item ids
     */
    public List<Long> getItemIds() {
        List<Long> result = new ArrayList<>(rcmdList.size());
        for (ItemScore itemScore : rcmdList) {
            result.add(itemScore.getItemId());
        }
        return result;
    }

    /**
     * Replaces the content of the filter list with the given ids.
     *
     * @param list new content; passing the list returned by {@code getFilterList()} is a no-op
     */
    public void setFilterList(List<Long> list) {
        // Guard against aliasing: clearing first would otherwise wipe the very list we are
        // about to copy from.
        if (list == filterList) {
            return;
        }
        filterList.clear();
        filterList.addAll(list);
    }

    /**
     * Creates the data flow handed to a DAGModel (sub-graph) module. Per the original note,
     * bzCode, sceneCode and pageSize are injected by the module, and the copy only serves
     * to build the query conditions: never call setRcmdList with it.
     *
     * <p>extraJson, extraMap and itemExpo are shared by reference; the filter list, user
     * profiles and experiment parameters are copied.
     *
     * @return the new data flow, flagged as a DAG call
     */
    public DataFlow copyForDagModel() {
        DataFlow dataFlow = new DataFlow();
        dataFlow.setRequestId(requestId);
        dataFlow.setUserId(userId);
        dataFlow.setDeviceId(deviceId);
        dataFlow.setVercode(vercode);
        dataFlow.setDeviceType(deviceType);
        // NOTE(review): the original set vercode a second time here; appID and subAppId are
        // not copied — confirm whether one of them was intended.
        dataFlow.setPageNo(pageNo);
        dataFlow.setPageSize(pageSize);
        dataFlow.setCount(pageSize);
        dataFlow.setExtraJson(extraJson);// may be needed by the sub-graph
        dataFlow.setDag(true);// mark as a DAG call
        dataFlow.setItemExpo(itemExpo);
        dataFlow.setExtraMap(extraMap);
        dataFlow.setFilterList(new FilterCollection(filterList));

        if (this.userProf != null) {
            UserProfile userProfile = new UserProfile(new ConcurrentHashMap<>(this.userProf.getProperty()), new ConcurrentHashMap<>(this.userProf.getOriginalProperty()));
            userProfile.setUserType(this.userProf.getUserType());
            dataFlow.setUserProf(userProfile);
        }

        if (this.njUserProf != null) {
            UserProfile njUserProfile = new UserProfile(new ConcurrentHashMap<>(this.njUserProf.getProperty()), new ConcurrentHashMap<>(this.njUserProf.getOriginalProperty()));
            dataFlow.setNjUserProf(njUserProfile);
        }

        // pass the layered-experiment parameters along
        dataFlow.setHitExps(hitExps);
        if (expParamMap != null) {
            dataFlow.setExpParamMap(new HashMap<>(expParamMap));
        }

        return dataFlow;
    }

    /**
     * Tells whether the request hit at least one of the given experiments.
     *
     * @param exps experiment names to look for
     * @return {@code true} if any name is one of the '|'-separated entries of hitExps
     */
    public boolean containsHitExps(String... exps) {
        if (StringUtils.isBlank(hitExps)) {
            return false;
        }

        // Wrapping both sides in '|' makes the substring test match whole entries only.
        String delimitedHits = "|" + hitExps + "|";
        for (String exp : exps) {
            if (delimitedHits.contains("|" + exp + "|")) {
                return true;
            }
        }

        return false;
    }

    /**
     * Derives the user's platform from the device type prefix.
     *
     * @return the Android, iOS or H5 platform constant, or the unknown-platform constant
     *         when the device type is {@code null} or has no recognised prefix
     */
    public String getUserPlatform() {
        if (deviceType == null) {
            return PLATFORM_UNKNOW;
        }
        if (deviceType.startsWith("Android")) {
            return PLATFORM_ANDROID;
        }
        if (deviceType.startsWith("IOS")) {
            return PLATFORM_IOS;
        }
        if (deviceType.startsWith("h5-web")) {
            return PLATFORM_H5;
        }
        return PLATFORM_UNKNOW;
    }

    /**
     * Sets one property of the recall profile, creating the profile in the extension map
     * on first use.
     *
     * @param key   profile key
     * @param value profile value
     */
    public synchronized void addRecallProfile(String key, Object value) {
        UserProfile recallProfile = getFromExtraMap(RECALL_PROFILE_KEY);
        if (recallProfile == null) {
            recallProfile = new UserProfile();
            addExtraMap(RECALL_PROFILE_KEY, recallProfile);
        }

        recallProfile.setProperty(key, value);
    }

    /**
     * Reads one property of the recall profile.
     *
     * @param key          profile key
     * @param defaultValue value returned when there is no recall profile or no such property
     * @param <T>          type the caller expects; the cast is unchecked
     * @return the property value, or {@code defaultValue}
     */
    @SuppressWarnings("unchecked")
    public synchronized <T> T getRecallProfile(String key, T defaultValue) {
        UserProfile recallProfile = getFromExtraMap(RECALL_PROFILE_KEY);
        if (recallProfile == null) {
            // Honour the caller's default here too, as for a missing property below.
            return defaultValue;
        }

        Object value = recallProfile.getProperty(key);
        return value == null ? defaultValue : (T) value;
    }

    /**
     * Reads one property of the recall profile.
     *
     * @param key profile key
     * @param <T> type the caller expects
     * @return the property value, or {@code null} when absent
     */
    public synchronized <T> T getRecallProfile(String key) {
        return getRecallProfile(key, null);
    }

    /**
     * Adds a Map-typed recall profile.
     *
     * @param key   profile key
     * @param value profile value: map key is the seed, map value is that seed's score
     */
    public void addRecallProfileMap(String key, Map<String, Float> value) {
        addRecallProfile(key, value);
    }

    /**
     * Adds a JSONObject-typed recall profile.
     *
     * @param key   profile key
     * @param value profile value, format {"seed1":0.2,"seed2":0.1}: key is the seed, value
     *              is that seed's score
     */
    public void addRecallProfileJSONObject(String key, JSONObject value) {
        addRecallProfile(key, value);
    }

    /**
     * Adds a JSONArray-typed recall profile.
     *
     * @param key   profile key
     * @param value profile value, format [{"key1":"value1","key2":"value2"}]
     */
    public void addRecallProfileJSONArray(String key, JSONArray value) {
        addRecallProfile(key, value);
    }

    /**
     * Adds a vector-typed recall profile.
     *
     * @param key   profile key
     * @param value profile value as text, format [0.1,-0.2,0.3,0.4]
     */
    public void addRecallProfileVector(String key, String value) {
        addRecallProfile(key, value);
    }
}

 

  • 0
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值