回溯算法基础
回溯其实都听了很多了,但是到底着怎么回事一个模模糊糊的。遇到很多问题,我们都想着要回溯,但是没搞明白,怎么回溯,回到哪里。
回溯就是一个搜索,回溯函数就是一个递归函数,在二叉树的时候,一般有递归,就会有回溯。比如后序遍历,虽然只给了root节点,但是我们是从下往上走的。
回溯法的效率
回溯法的性能如何呢,这里要和大家说清楚了,虽然回溯法很难,很不好理解,但是回溯法并不是什么高效的算法。
因为回溯的本质是穷举,穷举所有可能,然后选出我们想要的答案,如果想让回溯法高效一些,可以加一些剪枝的操作,但也改不了回溯法就是穷举的本质。
为什么用回溯,因为有些问题就只能用穷举,只能暴力搜索,比如组合 100数中选出50个数,你如果用循环,那就要用50个for,你怎么写?
回溯法解决的问题
回溯法,一般可以解决如下几种问题:
- 组合问题:N个数里面按一定规则找出k个数的集合
- 切割问题:一个字符串按一定规则有几种切割方式
- 子集问题:一个N个数的集合里有多少符合条件的子集
- 排列问题:N个数按一定规则全排列,有几种排列方式
- 棋盘问题:N皇后,解数独等等
这些问题你如果不用递归,不用回溯,你用for循环,咋写!
如何理解回溯法
回溯法解决的问题都可以抽象为树形结构,是的,我指的是所有回溯法的问题都可以抽象为树形结构!
因为回溯法解决的都是在集合中递归查找子集,集合的大小就构成了树的宽度,递归的深度就构成了树的深度。
回溯法模板
其实也是三部曲
1. 回溯函数模范返回值和参数
回溯函数一般没有返回值,直接return就可以了
参数的话,一般都是根据题意确定
2. 回溯方法中止 条件
什么时候达到了终止条件,树中就可以看出,一般来说搜到叶子节点了,也就找到了满足条件的一条答案,把这个答案存放起来,并结束本层递归。
3. 单层回溯函数遍历过程
for (选择:本层集合中元素(树中节点孩子的数量就是集合的大小)) {
处理节点;
backtracking(路径,选择列表); // 递归
回溯,撤销处理结果
}
大家可以从图中看出for循环可以理解是横向遍历,backtracking(递归)就是纵向遍历,这样就把这棵树全遍历完了,一般来说,搜索叶子节点就是找的其中一个结果了。
接下来组合相关的算法模板如下,都是有套路的。
void backtracking(参数) {
if (终止条件) {
存放结果;
return;
}
for (选择:本层集合中元素(树中节点孩子的数量就是集合的大小)) {
处理节点;
backtracking(路径,选择列表); // 递归
回溯,撤销处理结果
}
}
77 组合
给定两个整数 n 和 k,返回 1 ... n 中所有可能的 k 个数的组合。
示例: 输入: n = 4, k = 2 输出: [ [2,4], [3,4], [2,3], [1,2], [1,3], [1,4] ]
这是经典的回溯问题,也是让很多人觉得不好解决的问题。
比如拿例子中的具体,k=2,我们很容易想到用两个for循环,这样就可以了
但是如果n=100, k=50呢,那你难道用50个k循环?
而且你该怎么编程呢,k是变的!你用for写都写不出来。
解决方案就是用回溯法,用递归来解决嵌套层数的问题。每一次的递归中嵌套一个for循环,那么递归就可以用于解决多层嵌套循环的问题了。
如果这个题抽象为一个数,如下图,就好理解了。图片来源于代码随想录。
n 相当于数的宽度,k相当于数的深度!图中每次搜索到了叶子节点,我们就找到了一个结果。
那么我就就可以利用模板来解决这个问题了
回溯三部曲
1. 递归参数的返回值以及参数
返回值没有,我们定义全局变量来保存结果
参数就有 n 遍历的数字 k 多少个数 startIndex 起始索引,为了让遍历的数的区间进行移动
path 保存符合结果的临时变量, result 保存的结果
startIndex是个很重要的参数,因此这可以让你遍历不重复,像树结构那样,当你遍历到1,那么就不会再回到1
2. 递归函数的中止条件
到path的大小 已经等于K的时候,说明就可以返回了。
3. 单层递归
for循环每次从startIndex开始遍历,然后用path保存取到的节点i。
可以看出backtracking(递归函数)通过不断调用自己一直往深处遍历,总会遇到叶子节点,遇到了叶子节点就要返回。
最后你的回溯,将这个结果删除
代码如下
class Solution:
def combine(self, n: int, k:int):
result = []
self.backtracking(n, k, 1, [], result)
return result
def backtracking(self, n, k, startIndex, path, result):
if len(path) == k:
result.append(path)
return
for i in range(startIndex, n):
path.append(i)
self.backtracking(n, k, i+1, path, result)
path.pop()
组合的剪枝
其实这个排列可以剪枝的,如下图 假设n=4, k=4 其实刚开始的遍历不需要那么多
这里遍历的就是遍历的总次数,代码需要再for 那里改进
当 我们遍历的只是满足条件的,即 [1, n - (k- len(path)) + 2)
假设path为0,n=4, k=2
我们可以遍历的
1开头
2开头
3开头
4开头就不行了
优化的代码如下
class Solution:
def combine(self, n: int, k:int):
result = []
self.backtracking(n, k, 1, [], result)
return result
def backtracking(self, n, k, startIndex, path, result):
if len(path) == k:
result.append(path)
return
for i in range(startIndex, n - (k-len(path)) + 2):
path.append(i)
self.backtracking(n, k, i+1, path, result)
path.pop()
216 组合总和III
找出所有相加之和为 n 的 k 个数的组合。组合中只允许含有 1 - 9 的正整数,并且每种组合中不存在重复的数字。
说明:
- 所有数字都是正整数。
- 解集不能包含重复的组合。
示例 1: 输入: k = 3, n = 7 输出: [[1,2,4]]
示例 2: 输入: k = 3, n = 9 输出: [[1,2,6], [1,3,5], [2,3,4]]
这个题目与组合非常相似,但是有以下不同。
1. 判断条件,这里要判断个数
2 即使个数对了,还要判断总和是否OK
3. 如果想要剪枝,条件也有不同,个数满足,累计和大于的时候,就该返回了,
另外关于宽度的剪枝,也是同样的操作
class Solution:
def combinationSum3(self, k: int, n: int) -> List[List[int]]:
path = []
result = []
self.backtracking(k, n, 0, path, result)
return result
def backtracking(self, k, n, startIndex, path, result):
if sum(path) > n:
return
if len(path) == k:
if sum(path) == n:
result.append(path)
return
for i in range(startIndex, 9-(k-len(path)) + 2):
path.append(i)
self.backtracking(k, n, i+1, path, result)
path.pop()
17 电话号码的字母组合
这个题其实也是组合的一个变种,但是,稍微加了一点难度。
我们之前一直在遍历数字,横向上,但是,这次是遍历字母,看似没有什么区别
但是,我们并不是遍历26个字母,而是根据出现的数字所映射的字母去遍历。
而且,不是所谓的你出现2个数字,对应6个字母,我们的宽度就是6个字母。
而是,当深度每家一层,我们遍历的字母是不一样的。比如你有2个数字,2,3
分别对应字母 abc , def
你一层的宽度的遍历的字母是abc
第二层遍历的字母是def,示意图如下
因此,在遍历模板的for 循环中,startIndex的作为就是在深度加深的时候,去匹配不同的字符
这下,我们来构造回溯三部曲
1. 递归参数和返回值
递归参数其实是有我们digital的长度,map_list, startIndex, path, result
其中maplist 对应的字母表,map_list[start_index] 就对应字母list
2. 递归中止条件
path的长度等于字符串的长度就返回
3. 单层递归
其实是一样,只不过for循环需要改
还有就是递归的深度,我们startindex最大要与digitial_length对应
class Solution:
def letterCombinations(self, digits: str) -> List[str]:
if digits == "":
return []
dic = {
"2": ["a", "b", "c"],
"3": ["d", "e", "f"],
"4": ["g", "h", "i"],
"5": ["j", "k", "l"],
"6": ["m", "n", "o"],
"7": ["p", "q", "r", "s"],
"8": ["t", "u", "v"],
"9": ["w", "x", "y", "z"]
}
map_list = []
# 这里是假设每个数字都不同,你也可以去重
# digits = set(list(digits))
length = len(digits)
for i in digits:
map_list.append(dic[i])
path= = []
result = []
self.backtracking(length, 0, map_list, path, result)
return result
def backtracking(self, length, startIndex, maplist, path, result):
if len(path) == length:
result.append("".join(path))
return
if startIndex <= length-1:
for item in maplist[startIndex]:
path.append(item)
self.backtracking(length, startIndex+1, path, result)
path.pop()