回溯三部曲
- 递归函数的返回值以及参数
- 回溯函数终止条件
- 单层搜索的过程
void backtracking(参数) {
if (终止条件) {
存放结果;
return;
}
for (选择:本层集合中元素(树中节点孩子的数量就是集合的大小)) {
处理节点;
backtracking(路径,选择列表); // 递归
回溯,撤销处理结果
}
}
组合
回溯法解决的问题都可以抽象为树形结构(N叉树),用树形结构来理解回溯就容易多了。
可以看出这棵树,一开始集合是 1,2,3,4, 从左向右取数,取过的数,不再重复取。
第一次取1,集合变为2,3,4 ,因为k为2,我们只需要再取一个数就可以了,分别取2,3,4,得到集合[1,2] [1,3] [1,4],以此类推。
每次从集合中选取元素,可选择的范围随着选择的进行而收缩,调整可选择的范围。
图中可以发现n相当于树的宽度,k相当于树的深度。
那么如何在这个树上遍历,然后收集到我们要的结果集呢?
图中每次搜索到了叶子节点,我们就找到了一个结果。
相当于只需要把达到叶子节点的结果收集起来,就可以求得 n个数中k个数的组合集合。
class Solution {
public:
//回溯法
vector<int> path;
vector<vector<int>> result;
void backtracing(int n, int k, int startIndex) {
//确定终止条件
if (path.size() == k) {
result.push_back(path);
return;
}
//一次循环
for (int i = startIndex; i <= n; i++) {
path.push_back(i);
backtracing(n, k, i + 1);
path.pop_back();
}
return;
}
vector<vector<int>> combine(int n, int k) {
backtracing(n, k, 1);
return result;
}
};
优化
优化过程如下:
已经选择的元素个数:path.size();
所需需要的元素个数为: k - path.size();
列表中剩余元素(n-i) >= 所需需要的元素个数(k - path.size())
在集合n中至多要从该起始位置 : i <= n - (k - path.size()) + 1,开始遍历
为什么有个+1呢,因为包括起始位置,我们要是一个左闭的集合。
举个例子,n = 4,k = 3, 目前已经选取的元素为0(path.size为0),n - (k - 0) + 1 即 4 - ( 3 - 0) + 1 = 2。
class Solution {
public:
//回溯法
vector<int> path;
vector<vector<int>> result;
void backtracing(int n, int k, int startIndex) {
//确定终止条件
if (path.size() == k) {
result.push_back(path);
return;
}
//一次循环
for (int i = startIndex; i < n-(k-path.size())+1; i++) {
path.push_back(i);
backtracing(n, k, i + 1);
path.pop_back();
}
return;
}
vector<vector<int>> combine(int n, int k) {
backtracing(n, k, 1);
return result;
}
};
216.组合总和
本题在前一题的基础上增加了判断组合的总和是否为给定的数,并不难
class Solution {
public:
vector<int> path;
vector<vector<int>> result, resultsum;
//求组合
void backtracing(int k, int startIndex) {
if (path.size() == k) {
result.push_back(path);
return;
}
for (int i = startIndex; i <= 9 - (k - path.size()) + 1; i++) {
path.push_back(i);
backtracing(k, i + 1);
path.pop_back();
}
}
vector<vector<int>> combinationSum3(int k, int n) {
int sum;
backtracing(k, 1);
for (int i = 0; i < result.size(); i++) {
sum = 0;
for (int j = 0; j < k; j++) {
sum = sum + result[i][j];
}
if (sum == n)resultsum.push_back(result[i]);
}
return resultsum;
}
};