一、回溯法
回溯法可以看成穷举法的升级版。回溯法非常适合解决由多个步骤组成,并且每个步骤都有多个选项的问题。 当我们在某一步选择了其中一个选项后,就进入下一步,然后又面临新的选项。就这样重复选择,直至最终的状态。
可以用树结构(解空间树)来表示用回溯法解决的问题的所有选项。在某一步有 n n n 个可能的选项,那么该步骤就可以看成是树结构中的一个节点,每个选项看成树中节点的连接线,经过这些连接线到达该节点的 n n n 个子节点。叶节点则对应着最终的状态:
- 如果在叶节点的状态满足题目的约束条件,那我们就找到了一个问题的解;
- 如果在叶节点的状态不满足约束条件,那就回溯到它的上一个节点尝试其它的选项;
- 如果上一个节点所有可能的选项都已经试过,则再回溯到它的上一个节点,直至所有节点的所有选项都尝试过;
- 如果所有节点的所有选项都已经尝试过却没有找到满足约束条件的最终状态,说明该问题无解。
所以,用回溯法求解问题的过程,就是搜索整个解空间树的过程。当搜索至任一节点时,先判断该节点是否包含问题的解,如果不包含,则跳过该节点的子节点的搜索,避免无效搜索,这个过程叫作 “剪枝”。
Tips:
回溯法求解的特点是在搜索过程中动态产生问题的解空间。在任意时刻,算法只保存从根结点到当前结点的路径。
二、案例
案例一:数字的全排列(没有重复数字)
题目:
给定一个没有重复数字的序列,返回其所有可能的全排列。
示例:
输入: [1,2,3]
输出:
[
[1,2,3],
[1,3,2],
[2,1,3],
[2,3,1],
[3,1,2],
[3,2,1]
]
分析:
深度优先搜索(DFS) + 回溯:
“回溯”可以理解为“状态重置”,就是回到上一步的状态。通常,我们要解决的问题是在一棵树上完成的,在这棵树上搜索需要的答案,一般使用深度优先搜索。
以数组 [1, 2, 3] 为例进行分析。
我们自己手写的话,只需要按顺序枚举每一位可能出现的情况,已经选择过的数字不能再使用,即
- 在枚举第一位的时候,有三种情况;
- 在枚举第二位的时候,前面已经用过的数字不能再用,此时只有两种情况;
- 在枚举第三位的时候,前面两个已经用过的数字不能再用,此时只有一种情况。
这样就能做到不重不漏,把可能的全排列都枚举出来,结果如下图所示:
如上图,将所有的结果写出来,用一个树形结构表示,就是这个问题的解空间树。每执行一次深度优先搜索,从树的根节点到叶节点的路径就是一个全排列。使用回溯法搜索全排列的解的过程如下图:
说明:
- 每一个节点表示全排列问题求解的不同阶段,也叫状态;
- 深度优先搜索的过程中,在搜索至叶节点时需要回溯,即返回上一步继续求解,这时的状态变量需要设置成为和先前一样;
- 因此,在回溯到上一层节点时,需要撤销上一次选择,即“状态重置”。
- 上图中,在 [1, 2] 处撤销 2 回到了 [1],重新选择了 3 到达了状态 [1, 3]。而在状态 [1, 3] 处,因为上一步撤销了 2,所以在这里才可以选择 2 从而达到状态 [1, 3, 2],这就是状态重置的意义。
参考代码:
class Solution {
public List<List<Integer>> permute(int[] nums) {
List<List<Integer>> list = new ArrayList<>();
backTrack(list, new ArrayList<Integer>(), nums);
return list;
}
public void backTrack(List<List<Integer>> list, List<Integer> tempList, int[] nums){
if(tempList.size() == nums.length){
list.add(new ArrayList<>(tempList));
}else{
for(int i = 0; i < nums.length; i++){
// 已经使用过的数字就不能再用了,continue跳出本次循环,进行下一次循环
if(tempList.contains(nums[i]))
continue;
tempList.add(nums[i]);
backTrack(list, tempList, nums);
// 回溯
tempList.remove(tempList.size() - 1);
}
}
}
}
案例二:数字的全排列(有重复数字)
题目:
给定一个可包含重复数字的序列,返回所有不重复的全排列。
示例:
输入: [1,1,2]
输出:
[
[1,1,2],
[1,2,1],
[2,1,1]
]
分析:
和上一题相比,这道题多了一个条件:有重复数字。我们可以用上一题的解法来做,然后去除掉结果中重复的全排列即可。
注意:
- 上一题没有重复数字,我们可以用 contains(nums[i]) 方法判断 nums[i] 是否被用过。对于有重复数字的情况,我们就要创建一个布尔型的数组来标记每一位数字在一次全排列过程中是否被用过,而不能再使用 contains() 方法。
思路一:利用 Set
用 Set 存储全排列的结果,利用 Set 不存储重复元素的特性去除重复的全排列。
参考代码:
class Solution {
public List<List<Integer>> permuteUnique(int[] nums) {
Set<List<Integer>> set = new HashSet<>();
boolean[] used = new boolean[nums.length];
backTrack(set, new ArrayList<Integer>(), nums, used);
return new ArrayList<>(set);
}
public void backTrack(Set<List<Integer>> set, List<Integer> tempList, int[] nums, boolean[] used){
if(tempList.size() == nums.length){
set.add(new ArrayList<>(tempList));
}else{
for(int i = 0; i < nums.length; i++){
if(used[i])
continue;
used[i] = true;
tempList.add(nums[i]);
backTrack(set, tempList, nums, used);
used[i] = false;
// 回溯
tempList.remove(tempList.size() - 1);
}
}
}
}
思路二:“剪枝”
设计剪枝函数:
- 先对给定的序列排序,使得重复元素位于相邻的位置上;
- 每次搜索前先判断 nums[i] 和 nums[i-1] 是否相等,若相等,说明重复,则跳过本次循环。
参考代码:
class Solution {
public List<List<Integer>> permuteUnique(int[] nums) {
Arrays.sort(nums);// 对数组排序,使得重复元素位于相邻的位置
List<List<Integer>> list = new ArrayList<>();
boolean[] used = new boolean[nums.length];
backTrack(list, new ArrayList<Integer>(), nums, used);
return list;
}
public void backTrack(List<List<Integer>> list, List<Integer> tempList, int[] nums, boolean[] used){
if(tempList.size() == nums.length){
list.add(new ArrayList<>(tempList));
}else{
for(int i = 0; i < nums.length; i++){
// used[i] 用过了就跳出本次循环
if(used[i])
continue;
// 剪枝函数
if(i > 0 && nums[i] == nums[i - 1] && !used[i - 1])
continue;
used[i] = true;
tempList.add(nums[i]);
backTrack(list, tempList, nums, used);
used[i] = false;
// 回溯
tempList.remove(tempList.size() - 1);
}
}
}
}
案例三:矩阵中的路径
题目:
分析:
回溯法:
- 创建一个和字符矩阵大小相同的布尔型矩阵,记录矩阵中的位置是否走过,走过为 true,没走过则为false;
- 从 (0, 0) 开始,遍历字符矩阵,先找到字符串的第一个字符在字符矩阵中的位置;
- 从该位置向左、右、上、下四个方向探索,匹配字符串的下一个字符;
- 若全部匹配成功,返回 true;否则,回溯至上一步,继续探索,直至遍历完整个字符数组;
- 探索过程中,要注意边界条件,行和列不能越界。
参考代码:
public class Solution {
public boolean hasPath(char[] matrix, int rows, int cols, char[] str){
boolean[] flag = new boolean[matrix.length];
for(int i = 0; i < rows; i++){
for(int j = 0; j < cols; j++){
if(backTrack(matrix, i, j, 0, rows, cols, str, flag))
return true;
}
}
return false;
}
/**
i:行坐标
j:列坐标
k:字符串的索引,表示当前匹配到第 k 位字符
rows:行数
cols:列数
str:要匹配的字符串
flag:标志矩阵
*/
public boolean backTrack(char[] matrix, int i, int j, int k, int rows, int cols, char[] str, boolean[] flag){
// 计算匹配到的第一个字符在矩阵中的位置
int index = i * cols + j;
// 递归结束条件
if(i < 0 || i >= rows || j < 0 || j >= cols || matrix[index] != str[k] || flag[index])
return false;
// 当 k = str.length - 1 时,说明整个字符串都匹配到了
if(k == str.length - 1)
return true;
flag[index] = true;
// 向上下左右四个方向搜索,匹配下一个字符
if(backTrack(matrix, i - 1, j, k + 1, rows, cols, str, flag) || // 向上
backTrack(matrix, i + 1, j, k + 1, rows, cols, str, flag) || // 向下
backTrack(matrix, i, j - 1, k + 1, rows, cols, str, flag) || // 向左
backTrack(matrix, i, j + 1, k + 1, rows, cols, str, flag)){ // 向右
return true;
}
// 没有匹配到,回溯
flag[index] = false;
return false;
}
}