基础理论
回溯算法实际上一个类似枚举的搜索尝试过程,主要是在搜索尝试过程中寻找问题的解,当发现已不满足求解条件时,就“回溯”返回,尝试别的路径。回溯法是一种选优搜索法,按选优条件向前搜索,以达到目标。但当探索到某一步时,发现原先选择并不优或达不到目标,就退回一步重新选择,这种走不通就退回再走的技术为回溯法,而满足回溯条件的某个状态的点称为“回溯点”。许多复杂的,规模较大的问题都可以使用回溯法,有“通用解题方法”的美称。
回溯算法其实就是一个不断探索尝试的过程,探索成功了也就成功了,探索失败了就在退一步,继续尝试……,并不高效
回溯法,一般可以解决如下几种问题:
- 组合问题:N个数里面按一定规则找出k个数的集合
- 切割问题:一个字符串按一定规则有几种切割方式
- 子集问题:一个N个数的集合里有多少符合条件的子集
- 排列问题:N个数按一定规则全排列,有几种排列方式
- 棋盘问题:N皇后,解数独等等
一般来说:组合问题和排列问题是在树形结构的叶子节点上收集结果,而子集问题就是取树上所有节点的结果。
如果是一个集合来求组合的话,就需要startIndex
如果是多个集合取组合,各个集合之间相互不影响,那么就不用startIndex
回溯法解决的问题都可以抽象为树形结构,因为回溯法解决的都是在集合中递归查找子集,集合的大小就构成了树的宽度,递归的深度,都构成的树的深度。
回溯法抽象为树形结构后,其遍历过程就是:for循环横向遍历,递归纵向遍历,回溯不断调整结果集。简言之,横向遍历,纵向递归的组合。
递归就要有终止条件,所以必然是一棵高度有限的树(N叉树)。
为什么要防止分支污染:https://mp.weixin.qq.com/s/yQm53cyiFiJ7NI-lFEa1UA
递归时值传递、引用传递的注意
起始索引,树枝去重、数层去重
回溯算法模板:
void backtracking(参数) {
if (终止条件) {
存放结果;
return;
}
for (选择:本层集合中元素(树中节点孩子的数量就是集合的大小)) {
处理节点;
backtracking(路径,选择列表); // 递归
回溯,撤销处理结果
}
}
组合
/**
* leetcode-77. 组合
*/
List<List<Integer>> res = new ArrayList<>();
List<Integer> list = new ArrayList<>();
public List<List<Integer>> combine(int n, int k) {
dfs(n,k,1);
return res;
}
public void dfs(int n,int k ,int startIndex){
if(list.size() == k){
res.add(new ArrayList<>(list));
return;
}
for (int i = startIndex; i <= n-(k-list.size())+1; i++) {
list.add(i);
dfs(n,k,i+1);
list.remove(list.size()-1); //防止分支污染
}
}
组合总和3
/**
* leetcode-216. 组合总和 III
*/
List<List<Integer>> res = new ArrayList<>();
List<Integer> list = new ArrayList<>();
public List<List<Integer>> combinationSum3(int k, int n) {
dfs(k,n,1);
return res;
}
public void dfs(int k,int n,int startIndex){
if(list.size() == k){
int sum = 0;
for (int i = 0; i < list.size(); i++) {
sum+=list.get(i);
}
if(sum == n){
res.add(new ArrayList<>(list));
}
return;
}
for (int i = startIndex; i <= 9-(k-list.size())+1; i++) {
list.add(i);
dfs(k,n,i+1);
list.remove(list.size()-1);
}
}
//简化
List<List<Integer>> ans = new ArrayList<>();
List<Integer> list = new ArrayList<>();
public List<List<Integer>> combinationSum3(int k, int n) {
dfs(k, n, 1, 0);
return ans;
}
public void dfs(int k, int n, int startIndex, int sum) {
if(sum > n) return;
if (k == list.size()) {
if (sum == n)
ans.add(new ArrayList<>(list));
return;
}
for (int i = startIndex; i <= 9-(k-list.size())+1; i++) {
list.add(i);
dfs(k, n, i + 1, sum + i);
list.remove(list.size() - 1);
}
}
电话号码的字母组合
/**
* leetcode-17.电话号码的字母组合
*/
List<String> ans = new ArrayList<>();
public List<String> letterCombinations(String digits){
if (digits == null || digits.length() == 0) {
return ans;
}
//初始对应所有的数字,为了直接对应2-9,新增了两个无效的字符串""
String[] numString = {"", "", "abc", "def", "ghi", "jkl", "mno", "pqrs", "tuv", "wxyz"};
dfs(digits,numString,0);
return ans;
}
StringBuilder stringBuilder = new StringBuilder();
public void dfs(String digits,String[] numString,int num){
if(num == digits.length()){
ans.add(new String(stringBuilder));
return;
}
String str = numString[digits.charAt(num)-'0'];
for (int i = 0; i < str.length(); i++) {
stringBuilder.append(str.charAt(i));
dfs(digits,numString,num+1);
stringBuilder.deleteCharAt(stringBuilder.length()-1);
}
}
组合总和
/**
* 39. 组合总和
* @param candidates
* @param target
* @return
*/
List<List<Integer>> ans = new ArrayList<>();
List<Integer> list = new ArrayList<>();
public List<List<Integer>> combinationSum(int[] candidates, int target) {
Arrays.sort(candidates); //排序,方便做剪枝操作
dfs(candidates,target,list,0,0);
return ans;
}
public void dfs(int[] candidates,int target,List<Integer> list,int idx,int sum){
if(sum == target){
ans.add(new ArrayList<>(list));
return;
}
//如果是一个集合来求组合的话,就需要startIndex
//如果是多个集合取组合,各个集合之间相互不影响,那么就不用startIndex
for (int i = idx; i < candidates.length; i++) {
if(sum + candidates[i] > target) break; //剪枝
list.add(candidates[i]);
dfs(candidates,target,list,i,sum+candidates[i]); //这里i没有+1,因为可以重复
list.remove(list.size()-1);
}
}
组合总和2
/**
* leetcode-40组合总和2
* @param candidates
* @param target
* @return
*/
List<List<Integer>> ans = new ArrayList<>();
List<Integer> list = new ArrayList<>();
public List<List<Integer>> combinationSum2(int[] candidates, int target) {
Arrays.sort(candidates);
dfs(candidates, target, 0, list, 0);
return ans;
}
public void dfs(int[] candidates, int target, int idx, List<Integer> list, int sum) {
if (sum == target) {
ans.add(new ArrayList<>(list));
return;
}
for (int i = idx; i < candidates.length && sum + candidates[i] <= target; i++) {
if(sum > target) break;
if (i > idx && candidates[i] == candidates[i - 1]) continue;
list.add(i);
dfs(candidates, target, i + 1, list, sum + candidates[i]);
list.remove(list.size() - 1);
}
}
分割回文串
/**
* leetcode-131. 分割回文串
* @param s
* @return
*/
List<List<String>> ans = new ArrayList<>();
List<String> list = new ArrayList<>();
public List<List<String>> partition(String s) {
dfs(s,0);
return ans;
}
public void dfs(String s,int startIndex){
if(startIndex >= s.length()){
ans.add(new ArrayList<>(list));
return;
}
for (int i = startIndex; i < s.length(); i++) {
if(isPalindrome(s,startIndex,i)){
list.add(s.substring(startIndex,i+1));
}else continue;
dfs(s,i+1);
list.remove(list.size()-1);
}
}
public boolean isPalindrome(String s,int start,int end){
System.out.println(s.substring(start,end+1));
for (int i = start,j = end; i <= j; i++,j--) {
if(s.charAt(i) != s.charAt(j)) return false;
}
return true;
}
复原ip地址
/**
* leetcode-93 复原ip地址
* @param s
* @return
*/
List<String> ans = new ArrayList<>();
List<String> list = new ArrayList<>();
public List<String> restoreIpAddresses(String s) {
//长度小于 4 或者大于 12 ,一定不能拼凑出合法的 ip 地址
if (s.length() < 4 || s.length() > 12) return ans;
dfs(s, 0);
return ans;
}
public void dfs(String s, int startIndex) {
if (list.size() == 4 && startIndex >= s.length()) { //分割4次
ans.add(String.join(".", list)); //将list中的元素用.分割并返回字符串
return;
}
for (int i = startIndex; i < s.length(); i++) {
if (isIp(s, startIndex, i)) {
list.add(s.substring(startIndex, i + 1));
dfs(s, i + 1);
list.remove(list.size() - 1);
} else break;
}
}
public boolean isIp(String s, int start, int end) {
// System.out.println(s.substring(start, end + 1));
if (start > end) {
return false;
}
if (s.charAt(start) == '0' && start != end) { // 0开头的数字不合法
return false;
}
int num = 0;
for (int i = start; i <= end; i++) {
if (s.charAt(i) > '9' || s.charAt(i) < '0') { // 遇到⾮数字字符不合法
return false;
}
num = num * 10 + (s.charAt(i) - '0');
if (num > 255) { // 如果⼤于255了不合法
return false;
}
}
return true;
}
子集问题
/**
* 78. 子集
* @param nums
* @return
*/
List<List<Integer>> ans = new ArrayList<>();
List<Integer> list = new ArrayList<>();
public List<List<Integer>> subsets(int[] nums) {
dfs(nums,0);
return ans;
}
public void dfs(int[] nums,int startIndex){
ans.add(new ArrayList<>(list)); //所有节点均添加
if(startIndex >= nums.length){
return;
}
for (int i = startIndex; i < nums.length; i++) {
list.add(nums[i]);
dfs(nums,i+1);
list.remove(list.size()-1);
}
}
子集2
/**
* 90. 子集 II
* @param nums
* @return
*/
List<List<Integer>> ans = new ArrayList<>();
List<Integer> list = new ArrayList<>();
public List<List<Integer>> subsetsWithDup(int[] nums) {
Arrays.sort(nums);
dfs(nums,0);
return ans;
}
public void dfs(int[] nums,int startIndex){
ans.add(new ArrayList<>(list));
if(startIndex >= nums.length) return;
for (int i = startIndex; i < nums.length ; i++) {
if(i>startIndex && nums[i] == nums[i-1]){
continue;
}
list.add(nums[i]);
dfs(nums,i+1);
list.remove(list.size()-1);
}
}
递增子序列
/**
* 491. 递增子序列
* @param nums
* @return
*/
List<List<Integer>> ans = new ArrayList<>();
List<Integer> list = new ArrayList<>();
public List<List<Integer>> findSubsequences(int[] nums) {
dfs(nums,0);
return ans;
}
public void dfs(int[] nums,int startIndex){
if(list.size() >= 2){
ans.add(new ArrayList<>(list));
// 注意这里不要加return,因为取得是树上的每个节点
}
if(startIndex >= nums.length) return;
//used数组只负责本层 新的一层会清空
int[] used = new int[201]; //去重数组,-100 <= nums[i] <= 100
for (int i = startIndex; i < nums.length; i++) {
if((!list.isEmpty() && nums[i] < list.get(list.size()-1)) || used[nums[i]+100] ==1){
// 当前元素小于list中的最后一个元素或者同一层中当前元素已被使用过
continue;
}
used[i+100] = 1;
list.add(nums[i]);
dfs(nums,i+1);
list.remove(list.size()-1);
}
}
全排列
/**
* 46. 全排列
*
* @param nums
* @return
*/
List<List<Integer>> ans = new ArrayList<>();
List<Integer> list = new ArrayList<>();
public List<List<Integer>> permute(int[] nums) {
boolean[] used = new boolean[nums.length];
dfs(nums,used);
return ans;
}
public void dfs(int[] nums, boolean used[]) {
if (list.size() == nums.length) {
ans.add(new ArrayList<>(list));
return;
}
for (int i = 0; i < nums.length; i++) {
if (!used[i]) {
list.add(nums[i]);
} else {
continue;
}
used[i] = true;
dfs(nums, used);
list.remove(list.size() - 1);
used[i] = false;
}
}
全排列2
/**
* 46. 全排列
*
* @param nums
* @return
*/
List<List<Integer>> ans = new ArrayList<>();
List<Integer> list = new ArrayList<>();
public List<List<Integer>> permuteUnique(int[] nums) {
Arrays.sort(nums);
boolean[] used = new boolean[nums.length];
dfs(nums, used);
return ans;
}
//used用于纵向 一个排列里一个元素只能使用一次
public void dfs(int[] nums, boolean used[]) {
if (list.size() == nums.length) {
ans.add(new ArrayList<>(list));
return;
}
for (int i = 0; i < nums.length; i++) {
// used[i - 1] == true,说明同⼀树⽀nums[i - 1]使⽤过
// used[i - 1] == false,说明同⼀树层nums[i - 1]使⽤过
// 如果同⼀树层nums[i - 1]使⽤过则直接跳过
if (i > 0 && nums[i] == nums[i - 1] && !used[i - 1]) {
continue;
}
if (!used[i]) {
list.add(nums[i]);
used[i] = true;
dfs(nums, used);
list.remove(list.size() - 1);
used[i] = false;
}
}
}
重新安排行程
/**
* 332. 重新安排行程
* @param tickets
* @return
*/
List<String> list = new ArrayList<>();
List<String> ans;
public List<String> findItinerary(List<List<String>> tickets) {
Collections.sort(tickets, new Comparator<List<String>>() { //可以用lambda表达式优化
@Override
public int compare(List<String> o1, List<String> o2) {
return o1.get(1).compareTo(o2.get(1));
}
});
boolean[] used = new boolean[tickets.size()];
list.add("JFK");
dfs(tickets,used);
return ans;
}
public boolean dfs(List<List<String>> tickets,boolean[] used){
if(list.size() == tickets.size()+1){
ans = new ArrayList<>(list);
return true;
}
for (int i = 0; i < tickets.size(); i++) {
if(!used[i] && tickets.get(i).get(0).equals(list.get(list.size()-1))){
list.add(tickets.get(i).get(1));
used[i] = true;
if(dfs(tickets,used)) return true; //返回 找到一条线路就可以
list.remove(list.size()-1);
used[i] =false;
}
}
return false;
}
N皇后
/**
* 51. N 皇后
* @param n
* @return
*/
List<List<String>> ans = new ArrayList<>();
public List<List<String>> solveNQueens(int n) {
char[][] chessboard = new char[n][n];
for (char[] chars : chessboard) {
Arrays.fill(chars,'.');
}
dfs(chessboard,0,n);
return ans;
}
public void dfs(char[][] chessboard, int row, int n) {
if (row == n) {
ans.add(charArrayToList(chessboard));
return;
}
for (int column = 0; column < n; column++) {
if (verify(chessboard,row,column)){
chessboard[row][column] = 'Q';
dfs(chessboard,row+1,n);
chessboard[row][column] = '.';
}
}
}
public boolean verify(char[][] chessboard, int row, int column) {
//不需要验证行,每次向下递归走一行生成一个皇后 行之间肯定不会重复
//验证列
int n = chessboard.length;
for (int i = 0; i < n; i++) {
if (chessboard[i][column] == 'Q') {
return false;
}
}
//验证45° 即'/'
for (int i = row - 1, j = column + 1; i >= 0 && j < n; i--,j++) {
if (chessboard[i][j] == 'Q') {
return false;
}
}
//验证135° 即'\'
for (int i = row - 1, j = column - 1; i >= 0 && j >= 0; i--,j--) {
if(chessboard[i][j] == 'Q'){
return false;
}
}
return true;
}
public List<String> charArrayToList(char[][] chessboard) {
List<String> list = new ArrayList<>();
for (char[] chars : chessboard) {
list.add(String.valueOf(chars));
}
return list;
}
解数独
/**
* 37. 解数独
* @param board
*/
public void solveSudoku(char[][] board) {
dfs(board);
}
public boolean dfs(char[][] board) { //和n皇后不同,要遍历每行每列
//一个for循环遍历棋盘的行,一个for循环遍历棋盘的列,
// 一行一列确定下来之后,递归遍历这个位置放9个数字的可能性!」
for (int i = 0; i < 9; i++) {
for (int j = 0; j < 9; j++) {
if (board[i][j] != '.') continue; //本来有数字 跳过
for ( char k = '1'; k <= '9'; k++) {
if(verify(board,k,i,j)){
board[i][j] = k;
if(dfs(board)) return true; // 如果找到合适一组立刻返回
board[i][j] = '.';
}
}
return false; //9个数字都试完了 返回false 所以没有递归终止条件也行
}
}
return true;
}
public boolean verify(char[][] board, char val, int row, int col) {
// 同行是否重复
for (int i = 0; i < 9; i++) {
if (board[row][i] == val) return false;
}
//同列是否重复
for (int i = 0; i < 9; i++) {
if (board[i][col] == val) return false;
}
//同九宫格内是否重复
int startRow = (row / 3) * 3;
int startCol = (col / 3) * 3;
for (int i = startRow; i < startRow + 3; i++) {
for (int j = startCol; j < startCol + 3; j++) {
if (board[i][j] == val) return false;
}
}
return true;
}