百度百科-递归:
递归做为一种算法在程序设计语言中广泛应用。 一个过程或函数在其定义或说明中有直接或间接调用自身的一种方法,它通常把一个大型复杂的问题层层转化为一个与原问题相似的规模较小的问题来求解,递归策略只需少量的程序就可描述出解题过程所需要的多次重复计算,大大地减少了程序的代码量。递归需要有边界条件、递归前进段和递归返回段
正文
最近写了一个比较难的功能,涉及到递归,之前不太会用递归,但是现在有涉及到一些重复业务的逻辑,会用递归解决问题。简单几行代码就可以解决问题。我认为递归需要稍微复杂的思考逻辑以及基础的技术功底。所以我在这儿由简入繁地总结一些我平时用到的递归算法情况:
求数组的最大值
需求:求数组[2,4,6,2,1,5,2,8,9,4,3,0]的最大值
思路:先设置初始值,之后逐个遍历判断,求出最大值
代码:
package com.example.demo;
import lombok.Data;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@Data
class Number{
private Integer max;
}
public class CircleDemo {
private static final Logger logger = LoggerFactory.getLogger(CircleDemo.class);
public void getMaxNum() {
Integer[] ss = new Integer[]{2, 4, 6, 2, 1, 5, 2, 8, 9, 4, 3, 0};
Number number = new Number();
//设置一个初始值
number.setMax(ss[0]);
getMaxNumCircle(ss, number, 1);
logger.info("max====>[{}]", number.getMax());
}
/**
* 取最大值[2,4,6,2,1,5,2,8,9,4,3,0]
*
* @param ss
* @param number
*/
public void getMaxNumCircle(Integer[] ss, Number number, int index) {
//Math.max比较值大小
number.setMax(Math.max(ss[index], number.getMax()));
//定义递归边界值
if (index < ss.length - 1) {
index++;
//再次递归
getMaxNumCircle(ss, number, index);
}
}
public static void main(String[] args) {
CircleDemo circleDemo = new CircleDemo();
circleDemo.getMaxNum();
}
}
结果:
给树形赋值
需求:有个树形,需要从下向上计算人数,父级节点的人数等于儿子节点叠加人数之和。
结构以及数据如图。
通过销售人员获取每一组、每一部门,每一个系统的销售量。首先获取销售人员的销售量,之后向上递归计算出每一级的销量。代码如下
public List<Tree> getAll() {
return treeMapper.getAll();
}
public void initTreeNumber() {
List<Tree> trees = getAll();
List<Tree> employees = trees.stream()
.filter(tree -> Objects.equals(tree.getType(), 3))
.collect(Collectors.toList());
calculate(employees, trees);
//更新
treeMapper.updateBatch(trees);
}
public void calculate(List<Tree> employees, List<Tree> trees) {
//分组
Map<String, List<Tree>> employeesByParentId = employees.stream()
.collect(Collectors.groupingBy(Tree::getParentId));
Set<String> parentIds = employeesByParentId.keySet();
//结束递归循环的标志量
if(parentIds.isEmpty()){
return;
}
List<Tree> parentTrees = new ArrayList<>();
for (String parentId : parentIds) {
List<Tree> treesInParentId = employeesByParentId.get(parentId);
int sum = treesInParentId.stream().mapToInt(Tree::getNumber).sum();
for (Tree tree : trees) {
if (Objects.equals(parentId, tree.getId())) {
tree.setNumber(sum);
parentTrees.add(tree);
}
}
}
//向上计算
calculate(parentTrees, trees);
}
结果:
路线图求最大人数
需求:并联求和,串联求最大
路线图的情况涉及的比较多,这里我举一个包含情况比较多的路线图
上面图包含了以下几种情况:
- 单纯的串联
- 单纯的并联
- 并联之中的分支包含串联
- 无连线的节点(也就是其他分支)
要求是:并联相加,串联取最大,无连线当做并联处理。根据计算规则,上面图的最终值为24
我的思路是先分成两个方法,一个是处理并联方法,另外一个是处理串联方法。首先获取首节点列表(不包含连线的节点当做首节点处理),之后循环遍历每个首节点,每个首节点的处理都是调用串联分支,若是遇到并联分支则调用并联方法。这样说比较晦涩难懂,直接上代码:
数据库建表语句:
create table nodes (
id varchar(36) not null comment'主键',
name varchar(50) not null comment '节点名称',
number int not null comment '数量',
PRIMARY key(id)
);
create table edges(
id varchar(36) not null comment '主键',
source varchar(36) not null comment '源',
target varchar(36) not null comment '目标',
PRIMARY key (source, target)
)
数据信息:
代码:
package com.myproject.demo.dbtest.service;
import com.myproject.demo.dbtest.mapper.EdgeMapper;
import com.myproject.demo.dbtest.mapper.NodeMapper;
import com.myproject.demo.dbtest.util.SystemException;
import com.myproject.demo.dbtest.vo.Edge;
import com.myproject.demo.dbtest.vo.Node;
import com.myproject.demo.dbtest.vo.NumberVO;
import org.apache.logging.log4j.util.Strings;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import java.util.*;
import java.util.stream.Collectors;
@Service
public class NodeService {
@Autowired
NodeMapper nodeMapper;
@Autowired
EdgeMapper edgeMapper;
/**
* 计算路线图值
*
* @return
*/
public int calculateNumber() {
List<Node> nodes = nodeMapper.getAll();
List<String> nodeIds = nodes.stream().map(Node::getId).collect(Collectors.toList());
List<Edge> edges = edgeMapper.getEdgesByNodeIds(nodeIds);
//定义接收最大值的实体类
NumberVO numberVO = new NumberVO(0);
//查找首节点
List<String> firstNodeIds = getFirstNodeIds(nodes, edges);
parallel(firstNodeIds, nodes, edges, numberVO);
return numberVO.getMaxNum();
}
/**
* 并联计算
*
* @param nodeIds
* @param nodes
* @param edges
* @param numberVO
*/
private void parallel(List<String> nodeIds,
List<Node> nodes,
List<Edge> edges,
NumberVO numberVO) {
List<NumberVO> maxNums = new ArrayList<>();
List<NumberVO> parallelEnds = new ArrayList<>();
for (String nodeId : nodeIds) {
NumberVO maxNum = new NumberVO(0);
//每个分支都当做串联分支处理,调用串联方法
circle(nodeId, nodes, edges, maxNum);
if (Strings.isNotEmpty(maxNum.getParallelEndId())) {
//当前并联分支到了某个结束节点时,找到当前并联分支的最终节点,处理并联结果得到最大值
parallelEnds.add(maxNum);
} else {
//串联分支完成遍历
maxNums.add(maxNum);
}
}
//找到当前并联分支的最终节点,处理并联结果得到最大值
if (!parallelEnds.isEmpty()) {
dealParallelChilds(nodes, edges, numberVO, maxNums, parallelEnds);
} else {
numberVO.setMaxNum(maxNums.stream().mapToInt(NumberVO::getMaxNum).sum());
}
}
/**
* 处理并联中的并联
*
* @param nodes
* @param edges
* @param numberVO
* @param maxNums
* @param parallelEnds
*/
private void dealParallelChilds(List<Node> nodes, List<Edge> edges, NumberVO numberVO, List<NumberVO> maxNums, List<NumberVO> parallelEnds) {
//分组
Map<String, List<NumberVO>> parallelEndGroupByEndId = parallelEnds.stream()
.collect(Collectors.groupingBy(NumberVO::getParallelEndId));
Set<String> endIds = parallelEndGroupByEndId.keySet();
//获取多个并联终点的最后一个节点
String parallelEndIds = parallelEnds.get(0).getParallelEndIds();
List<NumberVO> indexEndNumberVOs = getSortEndNumberVOs(endIds, parallelEndIds);
String endIdFinal = indexEndNumberVOs.get(indexEndNumberVOs.size() - 1).getParallelEndId();
NumberVO sumParam = new NumberVO(0);
for (NumberVO indexEndNumberVO : indexEndNumberVOs) {
String endId = indexEndNumberVO.getParallelEndId();
//获取当前结束节点对应的并联节点的总和
List<NumberVO> sourceNodeParams = parallelEndGroupByEndId.get(endId);
int sum = sourceNodeParams.stream().mapToInt(NumberVO::getMaxNum).sum();
if (!Objects.equals(endId, endIdFinal)) {
//要和上个节点做比较,得出最大值
NumberVO maxParam = new NumberVO(sum);
//比较当前求和结果和最终节点大小
Node endNode = getNode(nodes, endId);
maxParam.setMaxNum(Math.max(sum, endNode.getNumber()));
//处理下个节点
List<String> nextIds = getTargetIdsByEdges(edges, endId);
if (nextIds.size() == 1) {
//串联求取最大值,如果这里结束了,则代表已经走完,得到最大值
circle(nextIds.get(0), nodes, edges, maxParam);
if (Objects.equals(maxParam.getParallelEndId(), endIdFinal)) {
sumParam.setMaxNum(maxParam.getMaxNum() + sumParam.getMaxNum());
continue;
}
if (null != maxParam.getParallelTargetIds()
&& !maxParam.getParallelTargetIds().isEmpty()) {
//新的并联关系
NumberVO pparam = new NumberVO(0);
parallel(maxParam.getParallelTargetIds(), nodes, edges, pparam);
}
} else if (nextIds.size() > 1) {
//获取下次并联结果集
NumberVO sumParam1 = new NumberVO(0);
parallel(nextIds, nodes, edges, sumParam1);
}
} else {
sumParam.setMaxNum(sum + sumParam.getMaxNum());
}
}
maxNums.add(sumParam);
Node endFinalNode = getNode(nodes, endIdFinal);
int sumAll = maxNums.stream().mapToInt(NumberVO::getMaxNum).sum();
numberVO.setMaxNum(Math.max(sumAll, endFinalNode.getNumber()));
List<String> nextIds = getTargetIdsByEdges(edges, endIdFinal);
if (nextIds.size() == 1) {
numberVO.setParallelEndIdFinal(nextIds.get(0));
} else if (nextIds.size() > 1) {
//获取下次并联结果集
NumberVO sumParam1 = new NumberVO(0);
parallel(nextIds, nodes, edges, sumParam1);
}
}
/**
* 获取并联终点的排序(可能存在并联中的并联,所以有多个终点)
*
* @param endIds
* @param parallelEndIds
* @return
*/
private List<NumberVO> getSortEndNumberVOs(Set<String> endIds,
String parallelEndIds) {
List<NumberVO> indexEndNumberVOs = new ArrayList<>();
for (String endId : endIds) {
NumberVO indexParam = new NumberVO();
//size=1时代表当前的endId集合中包含最终节点
int index = parallelEndIds.indexOf(endId);
indexParam.setEndIndex(index);
indexParam.setParallelEndId(endId);
indexEndNumberVOs.add(indexParam);
}
//排序得到最大值
indexEndNumberVOs.sort(Comparator.comparingInt(NumberVO::getEndIndex));
return indexEndNumberVOs;
}
/**
* 获取当前节点
*
* @param nodes
* @param currentNodeId
* @return
*/
private Node getNode(List<Node> nodes, String currentNodeId) {
Optional<Node> optional = nodes.stream()
.filter(node -> Objects.equals(node.getId(), currentNodeId))
.findFirst();
if (!optional.isPresent()) {
throw new SystemException("系统不存在当前节点信息");
}
return optional.get();
}
/**
* 比较当前最大值
*
* @param currentNodeId 当前节点
* @param nodes 所有节点
* @param edges 所有连线
* @param maxNum 最大值
*/
private void circle(String currentNodeId,
List<Node> nodes,
List<Edge> edges,
NumberVO maxNum) {
//判断当前节点是否为并联终止节点
List<String> sourceIds = getSourceIdsByEdges(edges, currentNodeId);
if (sourceIds.size() > 1) {
maxNum.setParallelEndId(currentNodeId);
maxNum.setParallelEndIds(currentNodeId);
getEndId(currentNodeId, edges, maxNum);
return;
}
//获取当前分支信息
Optional<Node> currentNodes = nodes.stream().filter(node -> Objects.equals(node.getId(), currentNodeId)).findFirst();
if (!currentNodes.isPresent()) {
throw new SystemException("当前节点不存在");
}
Node currentNode = currentNodes.get();
maxNum.setMaxNum(Math.max(currentNode.getNumber(), maxNum.getMaxNum()));
//判断下个节点是否为并联节点
List<String> nextIds = getTargetIdsByEdges(edges, currentNodeId);
if (nextIds.size() == 1) {
//串联
circle(nextIds.get(0), nodes, edges, maxNum);
} else if (nextIds.size() > 1) {
//并联
maxNum.setParallelTargetIds(nextIds);
NumberVO parallelNumber = new NumberVO(0);
parallel(nextIds, nodes, edges, parallelNumber);
maxNum.setMaxNum(parallelNumber.getMaxNum());
if (Strings.isNotEmpty(parallelNumber.getParallelEndIdFinal())) {
circle(parallelNumber.getParallelEndIdFinal(), nodes, edges, maxNum);
}
}
}
/**
* 循环遍历得到结果集
*
* @param nodeId
* @param edges
* @param numberVO
*/
private void getEndId(String nodeId, List<Edge> edges, NumberVO numberVO) {
List<String> sourceIds = getSourceIdsByEdges(edges, nodeId);
if (!sourceIds.isEmpty()) {
numberVO.setParallelEndIds(numberVO.getParallelEndIds() + "," + nodeId);
List<String> nextIds = getTargetIdsByEdges(edges, nodeId);
if (!nextIds.isEmpty()) {
getEndId(nextIds.get(0), edges, numberVO);
}
}
}
/**
* 获取当前节点的上个节点列表
*
* @param allEdgePOs
* @param targetId
* @return
*/
private List<String> getSourceIdsByEdges(List<Edge> allEdgePOs, String targetId) {
return allEdgePOs.stream()
.filter(allEdgePO -> Objects.equals(targetId, allEdgePO.getTarget()))
.map(Edge::getSource)
.collect(Collectors.toList());
}
/**
* 获取当前节点的下个节点列表
*
* @param allEdgePOs
* @param sourceId
* @return
*/
private List<String> getTargetIdsByEdges(List<Edge> allEdgePOs, String sourceId) {
return allEdgePOs.stream()
.filter(allEdgePO -> Objects.equals(sourceId, allEdgePO.getSource()))
.map(Edge::getTarget)
.collect(Collectors.toList());
}
/**
* 获取首节点列表
*
* @param nodes
* @param edges
* @return
*/
private List<String> getFirstNodeIds(List<Node> nodes, List<Edge> edges) {
List<String> firstNodeIds = new ArrayList<>();
for (Edge edge : edges) {
boolean isExists = edges.stream()
.anyMatch(edge1 -> Objects.equals(edge.getSource(), edge1.getTarget()));
if (!isExists) {
firstNodeIds.add(edge.getSource());
}
}
for (Node node : nodes) {
String nodeId = node.getId();
boolean isExists = edges.stream()
.anyMatch(edgePO -> (Objects.equals(edgePO.getSource(), nodeId) || Objects.equals(edgePO.getTarget(), nodeId)));
if (!isExists) {
firstNodeIds.add(nodeId);
}
}
return firstNodeIds.stream().distinct().collect(Collectors.toList());
}
}
打印结果,我采用的是使用controller接口调用的:
package com.myproject.demo.dbtest.controller;
import com.myproject.demo.dbtest.service.NodeService;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;
@RequestMapping("/nodes")
@RestController
public class NodeController {
@Autowired
NodeService nodeService;
@GetMapping
public int getMaxNumber() {
return nodeService.calculateNumber();
}
}
逻辑比较复杂,递归的过程中涉及到调用其他递归方法,但是这个要先想好思路,因为肯定是要从头开始计算,所以我先处理并联关系,之后再将并联分支当做串联处理。我技术渣渣,这个方法我想了两天一晚上才想出来,但是这个复杂的逻辑直接看是肯定看不懂的,所以要打debugger跟一下,实体类根据数据库建表语句创建就可以。如果有疑问或者好的意见,欢迎留言!
注意:在传值过程中,还要注意是引用传递还是值传递,如果采用的值传递,那么最终结果还是0,没有变化。
代码示例github地址:https://github.com/xueying123-cat/mybatis-exercise.git