本篇博客主要介绍一下回溯的基本思想。
什么是回溯?
回溯算法实际上是一个类似枚举的搜索尝试过程,主要是在搜索尝试过程中寻找问题的解,当发现已不满足求解条件时,就“回溯”返回,尝试别的路径。回溯法是一种选优搜索法,按选优条件向前搜索,以达到目标。但当探索到某一步时,发现原先选择并不优或达不到目标,就退回一步重新选择,这种走不通就退回再走的思想为回溯法,而满足回溯条件的某个状态的点称为“回溯点”。许多复杂的,规模较大的问题都可以使用回溯法,有“通用解题方法”的美称。
回溯算法也叫试探法,它是一种系统地搜索问题的解的方法。
用回溯法解决问题的一般步骤如下:
- 针对所给问题,定义问题的解空间,它至少包含问题的一个(最优)解;
- 确定易于搜索的解空间结构,使得能用回溯法方便地搜索整个解空间;
- 以深度优先的方式搜索解空间,并且在搜索过程中用剪枝函数避免无效搜索。
几个经典的例子
电话号码的字母组合
class Solution {
private String digits;
private List<String> result = new ArrayList<>();
private String[] maps = {
"", "", "abc", "def", "ghi", "jkl", "mno",
"pqrs", "tuv", "wxyz"
};
public List<String> letterCombinations(String digits) {
if (digits == null || digits.length() == 0) {
return this.result;
}
this.digits = digits;
backTrace(new StringBuilder(), 0);
return this.result;
}
private void backTrace(StringBuilder curStr, int curDepth) {
if (curDepth == this.digits.length()) {
if (curStr.length() != 0) {
this.result.add(new String(curStr));
}
return;
}
int curMapIndex = this.digits.charAt(curDepth) - '0';
char[] curMap = maps[curMapIndex].toCharArray();
for (char ch : curMap) {
backTrace(curStr.append(ch), curDepth + 1);
curStr.deleteCharAt(curStr.length() - 1);
}
}
}
二进制手表
class Solution {
private List<String> result = new ArrayList<>();
public List<String> readBinaryWatch(int num) {
int[] stat = new int[10];
backTrace(num, 0, 0, stat);
return result;
}
private void backTrace(int num, int start, int cnt, int[] stat) {
if (cnt == num) {
int hour = stat[0] * 8 + stat[1] * 4 + stat[2] * 2 + stat[3];
int minute = stat[4] * 32 + stat[5] * 16 + stat[6] * 8 + stat[7] * 4 + stat[8] * 2 + stat[9];
if (hour < 12 && minute < 60) {
String s = String.format("%d:%02d", hour, minute);
this.result.add(s);
}
return;
}
for (int i = start; i <= (9 - (num - cnt) + 1); ++i) {
stat[i] = 1;
backTrace(num, i + 1, cnt + 1, stat);
stat[i] = 0;
}
}
}
组合总和
class Solution {
public List<List<Integer>> combinationSum(int[] candidates, int target) {
List<List<Integer>> solutions = new ArrayList<>();
if (candidates == null || candidates.length == 0) {
return solutions;
}
List<Integer> solution = new ArrayList<>();
int curSum = 0;
backTrace(candidates, solutions, solution, curSum, 0, target);
return solutions;
}
private void backTrace(
int[] candidates, List<List<Integer>> solutions, List<Integer> solution,
int curSum, int prevPosition, int target
) {
if (curSum >= target) {
if (curSum == target) {
solutions.add(new ArrayList<>(solution));
}
return;
}
for (int i = prevPosition; i < candidates.length; ++i) {
if (candidates[i] > target) {
continue;
}
solution.add(candidates[i]);
backTrace(candidates, solutions, solution, curSum + candidates[i], i, target);
solution.remove(solution.size() - 1);
}
}
}
活字印刷
class Solution {
public int numTilePossibilities(String tiles) {
if (tiles.isEmpty()) {
return 0;
}
char[] tileArr = tiles.toCharArray();
HashSet<String> result = new HashSet<>();
int[] book = new int[tiles.length()];
backTrace(tileArr, new StringBuilder(), book, result);
return result.size();
}
private void backTrace(char[] tileArr, StringBuilder curStr, int[] book, HashSet<String> result) {
if (curStr.length() != 0) {
result.add(curStr.toString());
}
for (int i = 0; i < tileArr.length; ++i) {
if (book[i] == 1) {
continue;
}
book[i] = 1;
backTrace(tileArr, curStr.append(tileArr[i]), book, result);
curStr.deleteCharAt(curStr.length() - 1);
book[i] = 0;
}
}
}
N皇后
class Solution {
static class Pair {
public int row;
public int col;
public Pair(int row, int col) {
this.row = row;
this.col = col;
}
}
public List<List<String>> solveNQueens(int n) {
List<List<Pair>> solutions = new ArrayList<>();
List<Pair> solution = new ArrayList<>();
backTrace(solutions, solution, 0, n);
return transResult(solutions, n);
}
private void backTrace(
List<List<Pair>> solutions,
List<Pair> solution, int curRow, int n
) {
if (curRow == n) {
solutions.add(new ArrayList<>(solution));
}
for (int col = 0; col < n; ++col) {
if (isValid(solution, curRow, col)) {
solution.add(new Pair(curRow, col));
backTrace(solutions, solution, curRow + 1, n);
solution.remove(solution.size() - 1);
}
}
}
private boolean isValid(
List<Pair> solution, int row, int col
) {
// 判断当前行尝试的皇后位置是否和前面几行的皇后位置有冲突
for (Pair pair : solution) {
if (
pair.col == col ||
pair.row + pair.col == row + col ||
pair.row - pair.col == row - col
) {
return false;
}
}
return true;
}
private List<List<String>> transResult(
List<List<Pair>> solutions, int n
) {
char[] temp = new char[n];
List<List<String>> ret = new ArrayList<>();
for (int i = 0; i < solutions.size(); ++i) {
List<Pair> curSolution = solutions.get(i);
List<String> solu = new ArrayList<>();
for (int row = 0; row < n; ++row) {
Pair curPair = curSolution.get(row);
for (int col = 0; col < n; ++col) {
if (col == curPair.col) {
temp[col] = 'Q';
continue;
}
temp[col] = '.';
}
solu.add(String.valueOf(temp));
}
ret.add(solu);
}
return ret;
}
}