一、背景
假设某个App的首页推荐由2名策略产品经理、6名算法RD和2名工程RD协同开发。如果没有一套可视化的策略执行流程,沟通会非常低效,对接也容易出错,最终难以达到预期效果。
其次,推荐系统一般可以抽象出以下几个核心流程:
- 获取用户画像+用户过滤列表
- 召回
- 排序
- 重排
- 人工运营
这些核心流程可以抽象成公共组件,直接配置到执行流程中,从而大幅减少重复开发的工作量。
二、方案设计
1、设计思路
- 使用DAG(有向无环图)构建用户请求的处理模块(模块=DAG节点=策略)
- 使用一个数据流对象(DataFlow),串联DAG所有节点
- 使用锁控制并发:同一层的多个节点并行执行时,通过锁保护DataFlow中存在竞态条件的共享数据
- 通过UI修改DAG配置,近实时生效
- 支持同一个DAG中多个同名模块调用
- 支持子DAG调用
2、设计图
简易架构图
DAG编辑页面效果图
3、核心代码
Node类
/**
 * A vertex of the strategy DAG. Each node corresponds to one module/strategy.
 */
@Setter
@Getter
@ToString(exclude = {"nextEdges"})
public class Node {
    // ---- graph structure ----
    /** Label of this node; also its key in the owning graph's node map. */
    private String key;
    /** Outgoing edges, i.e. edges pointing at the nodes that depend on this one. */
    private List<Edge> nextEdges;
    /** Number of incoming edges; adjusted through addInDegree()/decreaseInDegree(). */
    private int inDegree;
    /** Separate in-degree counter reserved for pre-execution checks; kept apart from inDegree on purpose. */
    private int inDegreeCheck;
    /**
     * Model implementation bound to this node at execution time (an instance already created in memory).
     * A null value means the node is skipped during execution.
     */
    private GraphModel graphModel;

    /**
     * Creates an isolated node with no edges.
     *
     * @param key the node label
     */
    public Node(String key) {
        this.key = key;
        this.nextEdges = new LinkedList<>();
    }

    /** Increments {@code inDegree} and returns the new value. */
    public int addInDegree() {
        return ++inDegree;
    }

    /** Decrements {@code inDegree} and returns the new value. */
    public int decreaseInDegree() {
        return --inDegree;
    }

    /** Increments {@code inDegreeCheck} and returns the new value. */
    public int addInDegreeCheck() {
        return ++inDegreeCheck;
    }

    /** Decrements {@code inDegreeCheck} and returns the new value. */
    public int decreaseInDegreeCheck() {
        return --inDegreeCheck;
    }
}
Edge类
/**
 * A directed edge of the DAG. Only the target is stored; the edge itself is kept in
 * the source node's {@code nextEdges} list.
 */
@Setter
@Getter
@ToString
public class Edge {
    /** Node this edge points to. */
    private Node endNode;

    /**
     * @param target the node this edge points to
     */
    public Edge(Node target) {
        endNode = target;
    }
}
DirectedGraph类
/**
 * Directed acyclic graph (DAG) parsed from a plain-text edge list.
 *
 * <p>Script format: one edge per line, written as {@code startNode,endNode}. All whitespace inside
 * a line (spaces, tabs, the '\r' of CRLF line endings, ...) is ignored, and blank lines are skipped.
 * A node is created the first time its label appears. Nodes are stored in first-appearance order.
 */
@Setter
@Getter
public class DirectedGraph {
    /** Node label -> node, in the order the labels first appear in the script. */
    private Map<String, Node> directedGraphMap;

    /**
     * Parses {@code graphContent} into a graph.
     *
     * @param graphContent the edge-list script
     * @throws RuntimeException if the script is null or contains no edge, or if a line does not
     *     hold exactly two non-empty node labels
     */
    public DirectedGraph(String graphContent) {
        directedGraphMap = new LinkedHashMap<String, Node>();
        buildGraph(graphContent);
    }

    /** Builds the graph from the script text; each non-blank line contributes one edge. */
    private void buildGraph(String graphContent) {
        if (graphContent == null) {
            throw new RuntimeException("DAG脚本为空,至少需要一条边!");
        }
        // Split on the raw line breaks first, so reported line numbers match the script as written.
        String[] lines = graphContent.split("\n");
        int edgeCount = 0;
        for (int i = 0; i < lines.length; i++) {
            String line = lines[i].replaceAll("\\s+", "");
            if (line.isEmpty()) {
                // Blank lines (common in UI-edited scripts) are tolerated instead of failing the whole DAG.
                continue;
            }
            String[] nodesInfo = line.split(",");
            if (nodesInfo.length != 2) {
                throw new RuntimeException((i + 1) + "行包含" + nodesInfo.length + "节点,每行只能2个节点!");
            }
            if (nodesInfo[0].isEmpty() || nodesInfo[1].isEmpty()) {
                // e.g. ",b": split() keeps the leading empty label, which would otherwise become a node named "".
                throw new RuntimeException((i + 1) + "行包含空的节点名!");
            }
            Node startNode = directedGraphMap.computeIfAbsent(nodesInfo[0], Node::new);
            Node endNode = directedGraphMap.computeIfAbsent(nodesInfo[1], Node::new);
            // Each line is one edge: attach it to the start node and bump the end node's in-degrees.
            startNode.getNextEdges().add(new Edge(endNode));
            endNode.addInDegree();
            endNode.addInDegreeCheck();
            edgeCount++;
        }
        if (edgeCount == 0) {
            throw new RuntimeException("DAG脚本为空,至少需要一条边!");
        }
    }

    /**
     * Checks that the graph is acyclic, using Kahn's topological sort.
     *
     * <p>In-degrees are recomputed from the edges into a private map. The method therefore never
     * mutates node state and gives the same answer no matter how often it is called. (Consuming
     * {@code inDegreeCheck} directly made a second call on a cyclic graph able to return true.)
     *
     * @return true if every node can be visited in topological order, i.e. the graph has no cycle
     */
    public boolean validate() {
        Collection<Node> nodes = directedGraphMap.values();

        Map<String, Integer> remaining = new HashMap<>();
        for (Node node : nodes) {
            remaining.put(node.getKey(), 0);
        }
        for (Node node : nodes) {
            for (Edge edge : node.getNextEdges()) {
                remaining.merge(edge.getEndNode().getKey(), 1, Integer::sum);
            }
        }

        // Seed the queue with every node that has no incoming edge.
        Queue<Node> queue = new ArrayDeque<>();
        for (Node node : nodes) {
            if (remaining.get(node.getKey()) == 0) {
                queue.offer(node);
            }
        }

        int visited = 0;
        while (!queue.isEmpty()) {
            Node current = queue.poll();
            visited++;
            for (Edge edge : current.getNextEdges()) {
                Node next = edge.getEndNode();
                // A node becomes ready once all of its predecessors have been visited.
                if (remaining.merge(next.getKey(), -1, Integer::sum) == 0) {
                    queue.offer(next);
                }
            }
        }
        // Nodes on a cycle never reach in-degree 0, so they are never visited.
        return visited == nodes.size();
    }
}
GraphModelExecutor类(执行引擎的execute方法,按拓扑层次广度遍历DAG)
/**
 * Runs the DAG configured for the request's project and returns its execution state.
 *
 * <p>Steps:
 * <ol>
 *   <li>Run the request interceptors.</li>
 *   <li>Resolve the {@link Project}.</li>
 *   <li>Parse and validate its DAG script.</li>
 *   <li>Bind a {@link GraphModel} to every node.</li>
 *   <li>Walk the graph layer by layer, starting with the nodes whose in-degree is 0.</li>
 * </ol>
 * A layer with more than one node is executed with the parallel strategy selected by
 * {@code Config.getParallelWay()}. A single-node layer is handed to {@code dealOneNode} with that
 * node's timeout.
 *
 * @param dataFlow the per-request data object shared by all nodes
 * @return the DAG state. It is returned early, without running any node, when the script fails
 *     validation or when a sub-DAG node refers to the current project itself.
 * @throws NullPointerException if no project is registered for {@code dataFlow.getProjectKey()}
 */
public GraphState execute(final DataFlow dataFlow) {
    // Per-request record of the DAG execution.
    final GraphState graphState = new GraphState();
    Stopwatch stopwatch = Stopwatch.createStarted();
    // A/B-test interceptor; runs before the project lookup.
    aBtestInterceptor.interceptor(null, dataFlow);
    // A project is located by business code + scene code + abtest.
    final Project project = projectManager.getProjectMap().get(dataFlow.getProjectKey());
    Preconditions.checkNotNull(project, dataFlow.getProjectKey() + " has no project!");
    // pageSize interceptor
    pageSizeInterceptor.interceptor(project, dataFlow);
    // count interceptor (adds the redundant extra recall count)
    countInterceptor.interceptor(project, dataFlow);
    // itemType interceptor
    itemTypeInterceptor.interceptor(project, dataFlow);
    // Each element is one layer of nodes that become runnable together.
    final Queue<List<Node>> queue = new LinkedBlockingQueue<>();
    DirectedGraph directedGraph = new DirectedGraph(project.getDagScript());
    // Record the script first, so it is available even when validation fails.
    graphState.setDagScript(project.getDagScript());
    if (!directedGraph.validate()) {
        log.error("【{}】requestId={}`userId={},deviceId={},dagScript={},directedGraph is not right!", dataFlow.getProjectKey(), dataFlow.getRequestId(), dataFlow.getUserId(), dataFlow.getDeviceId(), project.getDagScript());
        return graphState;
    }
    // Resolve the model for every node and bind the in-memory instance to it.
    Collection<Node> nodes = directedGraph.getDirectedGraphMap().values();
    for (Node node : nodes) {
        // Node kinds: sub-DAG module, multi-instance module ("ModelA:a"), or plain module.
        // If the script was published before the code, no model is found: the node keeps a null model
        // (and is therefore skipped) and an alert is sent.
        GraphModel graphModel;
        if (GraphModelUtil.isDagModel(node.getKey())) {
            graphModel = graphModelManager.getModel(GraphModelManager.getSimpleName(DAGModel.class.getSimpleName()));
            String bzCodeSceneCode = node.getKey().replaceAll("\\$", "");
            if (bzCodeSceneCode.equals(dataFlow.getBzCodeSceneCodeKey())) {
                // A project must not embed itself as a sub-DAG (infinite recursion).
                log.error("构图错误,不能把自己作为一个图模块={}", node.getKey());
                return graphState;
            }
        } else if (BaseCommonUtil.isMultiInstanceModel(node.getKey())) {
            // "ModelA:a" resolves to the ModelA implementation.
            String[] modelKeyArr = node.getKey().split(BaseConsts.modelInstanceSpliter, 2);
            graphModel = graphModelManager.getModel(modelKeyArr[0]);
        } else {
            graphModel = graphModelManager.getModel(node.getKey());
        }
        if (graphModel != null) {
            // Bind only when the module is not currently circuit-broken.
            if (checkModelStat(project, node.getKey())) {
                node.setGraphModel(graphModel);
            }
        } else {
            // NOTE(review): this alert fires on every request that reaches this path; confirm
            // CommonUtil.sendWarningMsg throttles, otherwise a missing model can flood the channel.
            String env = projectManager.getEnv().endsWith(Consts.env.pre) ? "预发" : "线上";
            log.error("className={} has no model!", node.getKey());
            CommonUtil.sendWarningMsg(env + "环境," + node.getKey() + "模块异常", "推荐平台", HttpWechat.INFO_LEVEL_ERROR
                    , dataFlow.getProjectKey() + "的" + node.getKey() + "模块未找到实现类,可能代码未提交或模块实例化异常!", "");
        }
    }
    // First layer: every node with in-degree 0 (there may be several).
    List<Node> firstNodeList = new ArrayList<>();
    for (Node node : nodes) {
        if (node.getInDegree() == 0) {
            firstNodeList.add(node);
        }
    }
    queue.offer(firstNodeList);
    // Process the graph one layer at a time.
    while (!queue.isEmpty()) {
        final List<Node> visitNodeList = queue.poll();
        if (visitNodeList == null || visitNodeList.isEmpty()) {
            log.error("【{}】requestId={}`userId={},deviceId={} 该层没有任何节点!", dataFlow.getProjectKey(), dataFlow.getRequestId(), dataFlow.getUserId(), dataFlow.getDeviceId());
            continue;
        }
        // The next layer is computed and enqueued on this (main) thread, before the current layer runs.
        setNextNodeList(visitNodeList, queue);
        final boolean isParallel = visitNodeList.size() > 1;
        // isParallel is also passed down so single-node runs can be told apart in the logs.
        if (isParallel) {
            // Read the switch once, so a concurrent config change cannot split one decision across two reads.
            long parallelWay = Config.getConfig().getParallelWay();
            if (parallelWay == 1) {
                countDownLatch(project, visitNodeList, dataFlow, isParallel, graphState);
            } else if (parallelWay == 2) {
                completableFuture(project, visitNodeList, dataFlow, isParallel, graphState);
            } else {
                completeService(project, visitNodeList, dataFlow, isParallel, graphState);
            }
        } else {
            Node node = visitNodeList.get(0);
            long nodeTimeout = getTimeout(project, node.getKey());
            dealOneNode(project, node, dataFlow, isParallel, graphState, nodeTimeout, true);
        }
    }
    // After the whole project has run, apply the item_type post-processing.
    itemTypeAfterInterceptor.interceptor(project, dataFlow);
    recordFusingMsgByGraphState(project, graphState);
    log.info("【{}】requestId={}`userId={},deviceId={},executed all_nodes spends={}", project.getProjectKey(), dataFlow.getRequestId(), dataFlow.getUserId(), dataFlow.getDeviceId(), stopwatch.elapsed(TimeUnit.MILLISECONDS));
    return graphState;
}
DataFlow类
/**
 * Per-request data-flow object that is passed through the nodes of the DAG.
 *
 * <p>Use {@code rwLock} around sections with race conditions. {@code disableModelSet} is
 * backed by a concurrent map because it is read and written concurrently.
 */
@Setter
@Getter
@ToString(exclude = {"rwLock", "version", "filterList", "rcmdList", "userProf"})
@Slf4j
public class DataFlow {
    /*** required fields ***/
    private final List<Long> filterList = new FilterCollection<>(); // items to filter out
    private ReadWriteLock rwLock = new ReentrantReadWriteLock(); // lock for sections with race conditions
    private AtomicLong version = new AtomicLong(); // current version, incremented by version()
    /*** required fields ***/
    private long userId;
    private String deviceId;
    private int vercode;
    private String deviceType;
    private int appID;
    private long subAppId;
    private String bzCode; // business entity code
    private String sceneCode; // scene/page the user is on
    private int pageNo; // current page number
    private int pageSize; // items per page
    private int count; // page size + extra redundant items recalled for better results
    private List<ItemScore> rcmdList = new ArrayList<>(); // recommendation results
    private UserProfile userProf = new UserProfile(); // user profile
    private UserProfile njUserProf; // host-side user profile
    private List<ItemProfile> itemExpo; // currently exposed items (if any)
    private String abtest; // A/B-test tag
    private boolean isDag; // set when invoked as a sub-DAG; only the base group (A) may be called
    private Map<String, Object> extraMap = new HashMap<>(); // extra data exchanged between modules
    private JSONObject extraJson = new JSONObject(); // "extra" JSON parameters received in the request
    private JSONObject resultExtraJson = new JSONObject(); // request-level extra result, sibling of List<RcmdItem> (e.g. live dataversiontime)
    private boolean isOriginal; // return only the original data type (e.g. only playlists, without topics)
    private int itemType;
    private String requestId; // unique request id for log tracing
    private String hitExps; // all experiments hit on the A/B platform, e.g. exp1|exp2|exp3
    private Map<String, String> expParamMap = new HashMap<>(); // read-only experiment config, key=modelName
    // Read and written concurrently: add a model here when the experiment platform disables it.
    private Set<String> disableModelSet = Collections.newSetFromMap(new ConcurrentHashMap<String, Boolean>());
    private String baseInfo = null;

    /**
     * Increments the version counter and returns a one-line summary for logging.
     *
     * @return base info plus the new version and the current filter/recommendation list sizes
     */
    public String version() {
        return toBaseInfo()
                + "`version=" + version.incrementAndGet()
                + "`filterList=" + filterList.size()
                + "`rcmdList=" + rcmdList.size();
    }

    /**
     * Key used to find the project configuration for this request.
     *
     * @return {@code bzCode_sceneCode}, with {@code _abtest} appended when an A/B tag is set
     */
    public String getProjectKey() {
        String projectKey = getBzCodeSceneCodeKey();
        if (abtest != null && abtest.trim().length() > 0) {
            projectKey = projectKey + "_" + abtest;
        }
        return projectKey;
    }

    /**
     * Builds the basic identification string used in logs.
     *
     * <p>The string is built on the first call and cached. Later changes to the fields
     * (e.g. pageSize) are not reflected.
     *
     * @return the cached base info string
     */
    public String toBaseInfo() {
        if (baseInfo == null) {
            baseInfo = "userid=" + userId
                    + "`deviceid=" + deviceId
                    + "`projectKey=" + getProjectKey()
                    + "`pageNo=" + pageNo
                    + "`pageSize=" + pageSize
                    + "`hitExps=" + hitExps
                    + "`requestId=" + requestId;
        }
        return baseInfo;
    }

    /**
     * Stores a value in {@code extraMap}. Overwrites an existing key, logging a warning first.
     */
    public void addExtraMap(String key, Object val) {
        if (extraMap.containsKey(key)) {
            log.warn("key already exist`key={}`oldVal={}`newVal={}", key, extraMap.get(key), val);
        }
        extraMap.put(key, val);
    }

    /**
     * @return the value stored under {@code key}, cast to the caller's type, or null if absent
     */
    @SuppressWarnings("unchecked") // the caller is responsible for asking for the stored type
    public <T> T getFromExtraMap(String key) {
        return (T) extraMap.get(key);
    }

    /**
     * @return the value stored under {@code key}, or {@code defaultValue} if absent or null
     */
    @SuppressWarnings("unchecked") // the caller is responsible for asking for the stored type
    public <T> T getFromExtraMap(String key, T defaultValue) {
        Object value = extraMap.get(key);
        return value == null ? defaultValue : (T) value;
    }

    /**
     * bzCode_sceneCode identifies a project; abtest marks an A/B variant of that project.
     *
     * @return {@code bzCode + "_" + sceneCode}
     */
    public String getBzCodeSceneCodeKey() {
        return bzCode + "_" + sceneCode;
    }

    /**
     * @return the item ids of {@code rcmdList}, in list order
     */
    public List<Long> getItemIds() {
        List<Long> result = new ArrayList<>(rcmdList.size());
        for (ItemScore itemScore : rcmdList) {
            result.add(itemScore.getItemId());
        }
        return result;
    }

    /**
     * Replaces the filter list contents in place; the final list instance is kept.
     *
     * @param list new contents; null is treated as empty
     */
    public void setFilterList(List<Long> list) {
        filterList.clear();
        if (list != null) {
            filterList.addAll(list);
        }
    }

    /**
     * Creates the query DataFlow for a sub-DAG (DAGModel). bzCode, sceneCode and pageSize are
     * injected by the module. This copy only builds query conditions: never call setRcmdList on it.
     *
     * <p>NOTE(review): appID, subAppId and abtest are not copied (abtest presumably on purpose,
     * since sub-DAGs run the base group only). Also, extraMap and extraJson are shared with this
     * instance rather than copied. Confirm both are intended.
     *
     * @return a new DataFlow marked as a DAG call
     */
    public DataFlow copyForDagModel() {
        DataFlow dataFlow = new DataFlow();
        dataFlow.setRequestId(requestId);
        dataFlow.setUserId(userId);
        dataFlow.setDeviceId(deviceId);
        dataFlow.setVercode(vercode);
        dataFlow.setDeviceType(deviceType);
        dataFlow.setPageNo(pageNo);
        dataFlow.setPageSize(pageSize);
        dataFlow.setCount(pageSize);
        dataFlow.setExtraJson(extraJson); // may be needed by the sub-DAG
        dataFlow.setDag(true); // mark as a DAG call
        dataFlow.setItemExpo(itemExpo);
        dataFlow.setExtraMap(extraMap);
        dataFlow.setFilterList(new FilterCollection<>(filterList));
        if (this.userProf != null) {
            UserProfile userProfile = new UserProfile(new ConcurrentHashMap<>(this.userProf.getProperty()), new ConcurrentHashMap<>(this.userProf.getOriginalProperty()));
            userProfile.setUserType(this.userProf.getUserType());
            dataFlow.setUserProf(userProfile);
        }
        if (this.njUserProf != null) {
            UserProfile njUserProfile = new UserProfile(new ConcurrentHashMap<>(this.njUserProf.getProperty()), new ConcurrentHashMap<>(this.njUserProf.getOriginalProperty()));
            dataFlow.setNjUserProf(njUserProfile);
        }
        // pass layered-experiment parameters through
        dataFlow.setHitExps(hitExps);
        if (expParamMap != null) {
            dataFlow.setExpParamMap(new HashMap<>(expParamMap));
        }
        return dataFlow;
    }

    /**
     * Checks whether the request hit any of the given experiments.
     *
     * @param exps experiment names; each is matched exactly against the '|'-separated hitExps
     * @return true if at least one experiment was hit
     */
    public boolean containsHitExps(String... exps) {
        if (StringUtils.isBlank(hitExps)) {
            return false;
        }
        // Pad with separators once so every experiment name can be matched exactly as "|name|".
        String padded = "|" + hitExps + "|";
        for (String exp : exps) {
            if (padded.contains("|" + exp + "|")) {
                return true;
            }
        }
        return false;
    }

    /**
     * Maps the device type prefix to a platform constant.
     *
     * @return the Android, iOS or H5 platform constant, or PLATFORM_UNKNOW otherwise
     *     (including when deviceType is null)
     */
    public String getUserPlatform() {
        if (deviceType == null) {
            return PLATFORM_UNKNOW;
        }
        if (deviceType.startsWith("Android")) {
            return PLATFORM_ANDROID;
        } else if (deviceType.startsWith("IOS")) {
            return PLATFORM_IOS;
        } else if (deviceType.startsWith("h5-web")) {
            return PLATFORM_H5;
        } else {
            return PLATFORM_UNKNOW;
        }
    }

    /**
     * Sets a property on the recall profile kept under "RECALL_PROFILE" in extraMap,
     * creating that profile on first use.
     */
    public synchronized void addRecallProfile(String key, Object value) {
        UserProfile recallProfile = getFromExtraMap("RECALL_PROFILE");
        if (recallProfile == null) {
            recallProfile = new UserProfile();
            addExtraMap("RECALL_PROFILE", recallProfile);
        }
        recallProfile.setProperty(key, value);
    }

    /**
     * Reads a property from the recall profile.
     *
     * @return the property value, or {@code defaultValue} when either the recall profile or
     *     the property is absent
     */
    @SuppressWarnings("unchecked") // the caller is responsible for asking for the stored type
    public synchronized <T> T getRecallProfile(String key, T defaultValue) {
        UserProfile recallProfile = getFromExtraMap("RECALL_PROFILE");
        // Previously a missing profile returned null and ignored defaultValue.
        Object value = recallProfile == null ? null : recallProfile.getProperty(key);
        return value == null ? defaultValue : (T) value;
    }

    /**
     * @return the recall-profile property, or null if absent
     */
    public synchronized <T> T getRecallProfile(String key) {
        return getRecallProfile(key, null);
    }

    /**
     * Adds a map-typed recall profile.
     *
     * @param key   profile key
     * @param value seed value -> score of that seed
     */
    public void addRecallProfileMap(String key, Map<String, Float> value) {
        addRecallProfile(key, value);
    }

    /**
     * Adds a JSONObject-typed recall profile.
     *
     * @param key   profile key
     * @param value format: {"seed1":0.2,"seed2":0.1}, i.e. seed value -> score of that seed
     */
    public void addRecallProfileJSONObject(String key, JSONObject value) {
        addRecallProfile(key, value);
    }

    /**
     * Adds a JSONArray-typed recall profile.
     *
     * @param key   profile key
     * @param value format: [{"key1":"value1","key2":"value2"}]
     */
    public void addRecallProfileJSONArray(String key, JSONArray value) {
        addRecallProfile(key, value);
    }

    /**
     * Adds a vector-typed recall profile.
     *
     * @param key   profile key
     * @param value format: [0.1,-0.2,0.3,0.4]
     */
    public void addRecallProfileVector(String key, String value) {
        addRecallProfile(key, value);
    }
}