回溯章节理论基础:
77. 组合
题目链接:https://leetcode.cn/problems/combinations/
思路:
直接的解法当然是使用for循环,例如示例中k为2,很容易想到 用两个for循环,这样就可以输出 和示例中一样的结果。
如果n为100,k为50呢,那就50层for循环,这肯定不可行。虽然想暴力搜索,但是用for循环嵌套连暴力都写不出来。
那么我们这里就要用到回溯法了,用递归来做层叠嵌套,回溯法解决的问题都可以抽象为树形结构(N叉树),用树形结构来理解回溯就容易多了。
可以看出这棵树,一开始集合是 1,2,3,4, 从左向右取数,取过的数,不再重复取。
第一次取1,集合变为2,3,4 ,因为k为2,我们只需要再取一个数就可以了,分别取2,3,4,得到集合[1,2] [1,3] [1,4],以此类推。
图中可以发现n相当于树的宽度,k相当于树的深度。
每次搜索到了叶子节点,我们就找到了一个结果。相当于只需要把达到叶子节点的结果收集起来,就可以求得 n个数中k个数的组合集合。
那么,我们可以开始定义函数了。函数里一定有两个参数,既然是集合n里面取k个数,那么n和k是两个int型的参数。
然后还需要一个参数,为int型变量startIndex,这个参数用来记录本层递归中,集合从哪里开始遍历,不可能每次都是从1开始,1完了之后就是2,3,4这样依次下去。
所以需要startIndex来记录下一层递归,搜索的起始位置。
终止情况呢,就是path这个数组的大小(我这里名称用的是list)如果达到k,说明我们找到了一个子集大小为k的组合了,在图中path存的就是根节点到叶子节点的路径。此时用result二维数组,把path保存起来,并终止本层递归。
class Solution {
List<List<Integer>> res = new ArrayList<>();
List<Integer> list = new ArrayList<>();
public List<List<Integer>> combine(int n, int k) {
backtracking(n,k,1);
return res;
}
public void backtracking(int n, int k, int startIndex){
if(list.size() == k){
res.add(new ArrayList<>(list));
return ;
}
// 1,2,3,4
for(int i=startIndex;i<=n;i++){
list.add(i);
// 对于1,把2,3,4的情况考虑进来
backtracking(n, k, i+1);
// 回溯,把2,3,4,1值依次删去
list.removeLast();
}
}
}
时间复杂度: O(n * 2^n)
空间复杂度: O(n)
剪枝优化:
回溯法虽然是暴力搜索,但也有时候可以有点剪枝优化一下的,这个遍历的范围是可以剪枝优化的。
来举一个例子,n = 4,k = 4的话,那么第一层for循环的时候,从元素2开始的遍历都没有意义了。 在第二层for循环,从元素3开始的遍历都没有意义了。
所以,可以剪枝的地方就在递归中每一层的for循环所选择的起始位置。
如果for循环选择的起始位置之后的元素个数 已经不足 我们需要的元素个数了,那么就没有必要搜索了。
接下来看一下优化过程如下:
已经选择的元素个数:path.size();
还需要的元素个数为: k - path.size();
在集合n中至多要从该起始位置 : n - (k - path.size()) + 1,开始遍历
为什么有个+1呢,因为包括起始位置,我们要是一个闭合的区间。
这里可以举个例子更好理解,n = 4,k = 3, 目前已经选取的元素为0(path.size为0),n - (k - 0) + 1 即 4 - ( 3 - 0) + 1 = 2。
从2开始搜索都是合理的,可以是组合[2, 3, 4]。
class Solution {
List<List<Integer>> res = new ArrayList<>();
List<Integer> list = new ArrayList<>();
public List<List<Integer>> combine(int n, int k) {
backtracking(n,k,1);
return res;
}
public void backtracking(int n, int k, int startIndex){
if(list.size() == k){
res.add(new ArrayList<>(list));
return ;
}
// 1,2,3,4
// 优化:剪枝操作 k-list.size()是还剩下多少个元素
// 例如 n=4 k=3 就会到2为止 2,3,4
for(int i=startIndex;i<=n-(k-list.size())+1;i++){
list.add(i);
// 对于1,把2,3,4的情况考虑进来
backtracking(n, k, i+1);
// 回溯,把2,3,4,1值依次删去
list.removeLast();
}
}
}