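Backtracking Algorithms: Fundamentals
Every problem solved by backtracking can be modeled as a tree (an N-ary tree). Thinking of backtracking in terms of this tree makes it much easier to understand.
Backtracking typically solves the following kinds of problems:
- Combination problems: choose sets of K numbers out of N according to some rule
- Partitioning problems: find the ways a string can be split according to some rule
- Subset problems: find the subsets of an N-element set that satisfy a condition
- Permutation problems: list the permutations of N numbers that follow some rule
- Board problems: N-Queens, Sudoku, etc.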
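Backtracking template:
- Return value and parameters: the backtracking function usually returns nothing. Its parameters follow from the logic, so write the logic first and then add whatever parameters it turns out to need.
- Termination condition: when it is reached, store the result (a copy of the current path) and return.
- Traversal: in a for loop, process the current node, recurse (the function calls itself), then undo the processing (this undo is the actual backtracking step). The for loop walks across one level of the tree (breadth), while the recursion goes down the tree (depth).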
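The three parts of the template can be sketched as one generic Python function. `run_backtracking`, `is_solution` and `choices` are hypothetical names introduced only for illustration; in practice each problem inlines this logic rather than passing callbacks.

```python
from typing import Callable, Iterable, List, TypeVar

T = TypeVar("T")


def run_backtracking(
    is_solution: Callable[[List[T]], bool],
    choices: Callable[[List[T]], Iterable[T]],
) -> List[List[T]]:
    """Generic backtracking skeleton (illustrative sketch, hypothetical API)."""
    results: List[List[T]] = []
    path: List[T] = []

    def backtrack() -> None:
        # Termination condition: store a copy of the path and return
        if is_solution(path):
            results.append(path[:])
            return
        # The for loop walks across one level of the tree
        for choice in choices(path):
            path.append(choice)   # process the node
            backtrack()           # recursion goes one level deeper
            path.pop()            # undo the processing (backtrack)

    backtrack()
    return results
```

For example, `run_backtracking(lambda p: len(p) == 2, lambda p: "01")` enumerates all binary strings of length 2 as lists of characters.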
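Combinations
# Without pruning
from typing import List

class Solution:
    def combine(self, n: int, k: int) -> List[List[int]]:
        self.res = []   # all valid combinations
        # The combination being built. Named path rather than combine, because
        # self.combine would shadow the combine method on this instance
        self.path = []
        self.backtrack(n, k, 1)
        return self.res

    def backtrack(self, n, k, start_index):
        if len(self.path) == k:
            # Append a copy of self.path, not the list itself: the same list keeps
            # being modified by append()/pop() and ends up empty
            self.res.append(self.path[:])
            return
        for i in range(start_index, n + 1):
            self.path.append(i)          # process the node
            self.backtrack(n, k, i + 1)  # recurse
            self.path.pop()              # backtrack: undo the processing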
# With pruning
from typing import List

class Solution:
    def combine(self, n: int, k: int) -> List[List[int]]:
        self.res = []
        self.path = []
        self.backtrack(n, k, 1)
        return self.res

    def backtrack(self, n, k, start_index):
        if len(self.path) == k:
            # Append a copy: the list itself keeps changing and ends up empty
            self.res.append(self.path[:])
            return
        # Pruning: k - len(self.path) more numbers are still needed, so the
        # starting number i can be at most n - (k - len(self.path)) + 1
        for i in range(start_index, n - (k - len(self.path)) + 2):
            self.path.append(i)
            self.backtrack(n, k, i + 1)
            self.path.pop()
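To see what the pruning buys, the following sketch counts recursive calls with and without the upper bound on `i`. `combine_count` and its `prune` flag are hypothetical names for this illustration only.

```python
from typing import List, Tuple


def combine_count(n: int, k: int, prune: bool) -> Tuple[List[List[int]], int]:
    """Return all k-combinations of 1..n and the number of backtrack calls made."""
    res: List[List[int]] = []
    path: List[int] = []
    calls = 0

    def backtrack(start: int) -> None:
        nonlocal calls
        calls += 1
        if len(path) == k:
            res.append(path[:])
            return
        # With pruning, i stops at n - (k - len(path)) + 1; beyond that
        # there are not enough numbers left to complete the combination
        end = n - (k - len(path)) + 2 if prune else n + 1
        for i in range(start, end):
            path.append(i)
            backtrack(i + 1)
            path.pop()

    backtrack(1)
    return res, calls
```

With `n = k = 4` only one combination exists: the pruned search makes 5 calls (one per level), while the unpruned one visits all 16 subsets of {1, 2, 3, 4}.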
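Combination Sum III
from typing import List

class Solution:
    def combinationSum3(self, k: int, n: int) -> List[List[int]]:
        self.res = []
        self.path = []
        self.backtrack(n, k, 1)
        return self.res

    def backtrack(self, n, k, start_index):
        # Pruning: numbers only grow along a path, so once the sum exceeds n, stop
        if sum(self.path) > n:
            return
        # Check the length first, so the equality test on sum() only runs for
        # complete combinations. Once k numbers are chosen, stop whether or not
        # the sum matches; otherwise the loop below would continue past 9
        if len(self.path) == k:
            if sum(self.path) == n:
                self.res.append(self.path[:])
            return
        # Only digits 1-9 are allowed, and k - len(self.path) more are needed
        for i in range(start_index, 9 - (k - len(self.path)) + 2):
            self.path.append(i)
            self.backtrack(n, k, i + 1)
            self.path.pop()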
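A variant worth considering passes the remaining sum down the recursion instead of calling `sum()` at every node, and breaks out of the loop as soon as a digit exceeds what is left (the digits are tried in increasing order). This is a sketch; `combination_sum3` is a hypothetical standalone function, not the class method above.

```python
from typing import List


def combination_sum3(k: int, n: int) -> List[List[int]]:
    res: List[List[int]] = []
    path: List[int] = []

    def backtrack(start: int, remaining: int) -> None:
        if len(path) == k:
            if remaining == 0:
                res.append(path[:])
            return
        # Same upper bound as above: leave room for k - len(path) digits
        for i in range(start, 9 - (k - len(path)) + 2):
            # Digits increase, so if i is already too big, every later one is too
            if i > remaining:
                break
            path.append(i)
            backtrack(i + 1, remaining - i)
            path.pop()

    backtrack(1, n)
    return res
```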
Letter Combinations of a Phone Number
from typing import List

class Solution:
    def letterCombinations(self, digits: str) -> List[str]:
        if not digits:
            return []
        self.res = []
        self.answer = ''
        self.number = list(digits)
        self.map = {'2': 'abc',
                    '3': 'def',
                    '4': 'ghi',
                    '5': 'jkl',
                    '6': 'mno',
                    '7': 'pqrs',
                    '8': 'tuv',
                    '9': 'wxyz'}
        self.backtrack(0)
        return self.res

    def backtrack(self, index):
        # index is the position of the digit being processed (the depth in the tree)
        if index == len(self.number):
            self.res.append(self.answer)  # strings are immutable, so no copy is needed
            return
        letters = self.map[self.number[index]]
        for letter in letters:
            self.answer += letter
            self.backtrack(index + 1)
            self.answer = self.answer[:-1]  # undo: drop the last letter
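Because every digit contributes exactly one letter, this problem is a Cartesian product, so `itertools.product` gives the same list in the same order. The sketch below is handy for cross-checking the backtracking version; `letter_combinations_product` and `PHONE_LETTERS` are names chosen here for illustration.

```python
from itertools import product
from typing import List

PHONE_LETTERS = {'2': 'abc', '3': 'def', '4': 'ghi', '5': 'jkl',
                 '6': 'mno', '7': 'pqrs', '8': 'tuv', '9': 'wxyz'}


def letter_combinations_product(digits: str) -> List[str]:
    if not digits:
        return []
    groups = [PHONE_LETTERS[d] for d in digits]
    # product() advances the rightmost group fastest, like the recursion does
    return [''.join(letters) for letters in product(*groups)]
```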
Combination Sum
from typing import List

class Solution:
    def combinationSum(self, candidates: List[int], target: int) -> List[List[int]]:
        # Nothing can be formed if target is smaller than the smallest candidate
        if not candidates or target < min(candidates):
            return []
        self.len = len(candidates)
        self.result = []
        self.answer = []
        self.sum = 0
        self.backtrack(candidates, target, 0)
        return self.result

    def backtrack(self, candidates, target, start_index):
        if self.sum == target:
            self.result.append(self.answer[:])
            return
        if self.sum > target:
            return
        for i in range(start_index, self.len):
            self.answer.append(candidates[i])
            self.sum += candidates[i]
            # The same element may be picked any number of times, so recurse from i.
            # Later picks must not go back to earlier indices, otherwise the same
            # combination would appear again in a different order
            self.backtrack(candidates, target, i)
            self.sum -= candidates[i]
            self.answer.pop()
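If the candidates are sorted first, the loop can stop as soon as one candidate is larger than the remaining target, since every later candidate is larger too. A minimal sketch of that pruning, assuming all candidates are positive integers (a zero would make the recursion loop forever); `combination_sum` is a hypothetical standalone function.

```python
from typing import List


def combination_sum(candidates: List[int], target: int) -> List[List[int]]:
    candidates = sorted(candidates)
    res: List[List[int]] = []
    path: List[int] = []

    def backtrack(start: int, remaining: int) -> None:
        if remaining == 0:
            res.append(path[:])
            return
        for i in range(start, len(candidates)):
            # Sorted order: if this candidate is too big, so is every later one
            if candidates[i] > remaining:
                break
            path.append(candidates[i])
            backtrack(i, remaining - candidates[i])  # i, not i + 1: reuse allowed
            path.pop()

    backtrack(0, target)
    return res
```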
Combination Sum II
from typing import List

class Solution:
    def combinationSum2(self, candidates: List[int], target: int) -> List[List[int]]:
        # Nothing can be formed if target is smaller than the smallest candidate
        if not candidates or target < min(candidates):
            return []
        self.len = len(candidates)
        self.result = []
        self.answer = []
        self.sum = 0
        # Sort first so that equal values are adjacent
        candidates.sort()
        self.backtrack(candidates, target, 0)
        return self.result

    def backtrack(self, candidates, target, start_index):
        if self.sum == target:
            self.result.append(self.answer[:])
            return
        if self.sum > target:
            return
        for i in range(start_index, self.len):
            # No duplicate combinations allowed, so the same value must not be
            # used twice at the same tree level: skip it if it was already
            # tried at this level
            if i > start_index and candidates[i] == candidates[i - 1]:
                continue
            self.answer.append(candidates[i])
            self.sum += candidates[i]
            self.backtrack(candidates, target, i + 1)
            self.sum -= candidates[i]
            self.answer.pop()
# Deduplication with a used array
from typing import List

class Solution:
    def combinationSum2(self, candidates: List[int], target: int) -> List[List[int]]:
        if not candidates or target < min(candidates):
            return []
        self.len = len(candidates)
        self.result = []
        self.answer = []
        self.sum = 0
        self.used = [False] * len(candidates)
        candidates.sort()
        self.backtrack(candidates, target, 0)
        return self.result

    def backtrack(self, candidates, target, start_index):
        if self.sum == target:
            self.result.append(self.answer[:])
            return
        if self.sum > target:
            return
        for i in range(start_index, self.len):
            # Skip values already used at the same tree level. If candidates[i]
            # equals candidates[i - 1] but used[i - 1] is False, candidates[i - 1]
            # was tried earlier in this same for loop and has been undone, so
            # candidates[i] would repeat it. If used[i - 1] is True, it is on the
            # current path (same branch), which is allowed
            if i > 0 and candidates[i] == candidates[i - 1] and not self.used[i - 1]:
                continue
            self.answer.append(candidates[i])
            self.sum += candidates[i]
            self.used[i] = True
            self.backtrack(candidates, target, i + 1)
            self.used[i] = False
            self.sum -= candidates[i]
            self.answer.pop()
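Both deduplication rules can be combined with the "break when too big" pruning from sorted input. The sketch below uses the `i > start` rule; `combination_sum2` is a hypothetical standalone function, and candidates are assumed to be positive.

```python
from typing import List


def combination_sum2(candidates: List[int], target: int) -> List[List[int]]:
    candidates = sorted(candidates)
    res: List[List[int]] = []
    path: List[int] = []

    def backtrack(start: int, remaining: int) -> None:
        if remaining == 0:
            res.append(path[:])
            return
        for i in range(start, len(candidates)):
            if candidates[i] > remaining:
                break  # sorted: every later candidate is too big as well
            # Same value already tried at this tree level: skip it
            if i > start and candidates[i] == candidates[i - 1]:
                continue
            path.append(candidates[i])
            backtrack(i + 1, remaining - candidates[i])  # each element used once
            path.pop()

    backtrack(0, target)
    return res
```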
Palindrome Partitioning
from typing import List

class Solution:
    def partition(self, s: str) -> List[List[str]]:
        self.res = []
        self.answer = []
        self.backtrack(s, 0)
        return self.res

    def backtrack(self, s, start_index):
        # answer only ever holds palindromes (non-palindromic pieces are skipped),
        # so once start_index reaches the end of s, the whole string has been
        # cut into palindromes
        if start_index >= len(s):
            self.res.append(self.answer[:])
            return
        for i in range(start_index, len(s)):
            # Cut s[start_index..i]; only continue if that piece is a palindrome
            if self.is_palindrome(s, start_index, i):
                self.answer.append(s[start_index: i + 1])
                self.backtrack(s, i + 1)
                self.answer.pop()

    def is_palindrome(self, s, start, end):
        left = start
        right = end
        while left < right:
            if s[left] != s[right]:
                return False
            left += 1
            right -= 1
        return True
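`is_palindrome` rescans the substring every time it is called. A common refinement precomputes a table of palindromic substrings with dynamic programming, so each check during backtracking is O(1). A sketch of that idea; `partition` here is a standalone function and `is_pal` is a name chosen for illustration.

```python
from typing import List


def partition(s: str) -> List[List[str]]:
    n = len(s)
    # is_pal[i][j] is True when s[i..j] (inclusive) is a palindrome.
    # Row i depends on row i + 1, so fill the rows from the bottom up.
    is_pal = [[False] * n for _ in range(n)]
    for i in range(n - 1, -1, -1):
        for j in range(i, n):
            is_pal[i][j] = s[i] == s[j] and (j - i < 2 or is_pal[i + 1][j - 1])

    res: List[List[str]] = []
    path: List[str] = []

    def backtrack(start: int) -> None:
        if start == n:
            res.append(path[:])
            return
        for end in range(start, n):
            if is_pal[start][end]:
                path.append(s[start:end + 1])
                backtrack(end + 1)
                path.pop()

    backtrack(0)
    return res
```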
Restore IP Addresses
from typing import List

class Solution:
    def restoreIpAddresses(self, s: str) -> List[str]:
        # A valid address has 4 to 12 digits
        if len(s) < 4 or len(s) > 12:
            return []
        self.res = []
        self.answer = ''
        self.backtrack(s, 0, 0)
        return self.res

    def backtrack(self, s, start_index, times):
        # times is the number of dots placed so far
        if times == 3:
            # The rest of the string is the last segment
            if len(s[start_index:]) > 3:
                return
            self.answer += s[start_index:]
            if self.is_valid_ip(self.answer):
                self.res.append(self.answer)
            return
        # Each segment is 1 to 3 digits long
        for i in range(start_index, min(len(s), start_index + 3)):
            self.answer += s[start_index: i + 1] + '.'
            self.backtrack(s, i + 1, times + 1)
            # Undo: before this step, answer held start_index digits plus `times` dots
            self.answer = self.answer[: start_index + times]

    def is_valid_ip(self, s):
        for ip in s.split('.'):
            # Empty segment, value above 255, or a leading zero: invalid
            if not ip or int(ip) > 255 or (len(ip) > 1 and ip[0] == '0'):
                return False
        return True
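The solution above only validates segments once all four are built. An alternative sketch validates each segment as it is cut and prunes when the remaining digits cannot fill the remaining segments (each needs 1 to 3 digits). It assumes `s` contains only digits; `restore_ip_addresses` and `is_valid_segment` are hypothetical names.

```python
from typing import List


def restore_ip_addresses(s: str) -> List[str]:
    res: List[str] = []
    segments: List[str] = []

    def is_valid_segment(seg: str) -> bool:
        if len(seg) > 1 and seg[0] == '0':   # no leading zeros
            return False
        return int(seg) <= 255

    def backtrack(start: int) -> None:
        remaining = len(s) - start           # digits not yet used
        left = 4 - len(segments)             # segments still to cut
        if left == 0:
            if remaining == 0:
                res.append('.'.join(segments))
            return
        # Pruning: every remaining segment needs between 1 and 3 digits
        if remaining < left or remaining > 3 * left:
            return
        for length in range(1, 4):
            seg = s[start:start + length]
            if len(seg) < length:
                break                        # ran past the end of s
            if not is_valid_segment(seg):
                continue
            segments.append(seg)
            backtrack(start + length)
            segments.pop()

    backtrack(0)
    return res
```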
Subsets
from typing import List

class Solution:
    def subsets(self, nums: List[int]) -> List[List[int]]:
        self.res = []
        self.answer = []
        self.backtrack(nums, 0)
        return self.res

    def backtrack(self, nums, start_index):
        # Every node of the tree is a subset, so record it before checking
        # the termination condition
        self.res.append(self.answer[:])
        if start_index == len(nums):
            return
        for i in range(start_index, len(nums)):
            self.answer.append(nums[i])
            self.backtrack(nums, i + 1)
            self.answer.pop()
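For comparison, the subsets of an n-element list correspond one-to-one to the integers 0 to 2^n - 1: bit i of the mask says whether `nums[i]` is included. This bitmask enumeration is a different technique from backtracking, shown here as a sketch for cross-checking; `subsets_bitmask` is a hypothetical name.

```python
from typing import List


def subsets_bitmask(nums: List[int]) -> List[List[int]]:
    n = len(nums)
    # For each mask, keep nums[i] exactly when bit i is set
    return [[nums[i] for i in range(n) if mask >> i & 1] for mask in range(1 << n)]
```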
Subsets II
from typing import List

class Solution:
    def subsetsWithDup(self, nums: List[int]) -> List[List[int]]:
        self.res = []
        self.answer = []
        nums.sort()  # put equal values next to each other
        self.backtrack(nums, 0)
        return self.res

    def backtrack(self, nums, start_index):
        self.res.append(self.answer[:])
        if start_index == len(nums):
            return
        for i in range(start_index, len(nums)):
            # Same value already tried at this tree level: skip it
            if i > start_index and nums[i] == nums[i - 1]:
                continue
            self.answer.append(nums[i])
            self.backtrack(nums, i + 1)
            self.answer.pop()
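A slow but obviously correct reference is useful for testing the backtracking version. The sketch below generates every combination of the sorted input and lets a set drop duplicates. Sorting matters for the same reason as above: `combinations` keeps the input order, so without sorting, `[2, 1, 2]` would yield both `(2, 1)` and `(1, 2)`. `subsets_with_dup_bruteforce` is a hypothetical name.

```python
from itertools import combinations
from typing import List, Set, Tuple


def subsets_with_dup_bruteforce(nums: List[int]) -> Set[Tuple[int, ...]]:
    nums = sorted(nums)
    # Every multiset now maps to exactly one tuple, so the set removes duplicates
    return {c for r in range(len(nums) + 1) for c in combinations(nums, r)}
```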
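Non-decreasing Subsequences
Deduplicate with a per-level set (useg). Sorting cannot be used here because it would destroy the order of the subsequences.
from typing import List

class Solution:
    def findSubsequences(self, nums: List[int]) -> List[List[int]]:
        self.res = []
        self.answer = []
        self.backtrack(nums, 0)
        return self.res

    def backtrack(self, nums, start_index):
        # Record every path with at least two elements
        if len(self.answer) > 1:
            self.res.append(self.answer[:])
        if start_index == len(nums):
            return
        useg = set()  # values already used at this tree level
        for i in range(start_index, len(nums)):
            # Skip values that would break the order or repeat at this level
            if (self.answer and nums[i] < self.answer[-1]) or nums[i] in useg:
                continue
            useg.add(nums[i])
            self.answer.append(nums[i])
            self.backtrack(nums, i + 1)
            self.answer.pop()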
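As a test reference, the brute-force sketch below checks every combination of positions (which `combinations` yields in the original order) and keeps the non-decreasing ones of length at least 2. It runs in exponential time and is only meant for small inputs; `find_subsequences_bruteforce` is a hypothetical name.

```python
from itertools import combinations
from typing import List, Set, Tuple


def find_subsequences_bruteforce(nums: List[int]) -> Set[Tuple[int, ...]]:
    result: Set[Tuple[int, ...]] = set()
    for r in range(2, len(nums) + 1):
        for combo in combinations(nums, r):  # preserves the original order
            if all(a <= b for a, b in zip(combo, combo[1:])):
                result.add(combo)            # the set removes duplicates
    return result
```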
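Permutations
from typing import List

class Solution:
    def permute(self, nums: List[int]) -> List[List[int]]:
        self.res = []
        self.answer = []
        self.backtrack(nums)
        return self.res

    def backtrack(self, nums):
        # nums holds the elements not used yet; when it is empty, answer is complete
        if not nums:
            self.res.append(self.answer[:])
            return
        for i in range(len(nums)):
            self.answer.append(nums[i])
            _pop = nums.pop(i)       # remove the chosen element from the remaining ones
            self.backtrack(nums)
            nums.insert(i, _pop)     # undo: put it back at the same position
            self.answer.pop()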
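A simpler version that leaves nums untouched and skips values already on the current path (this relies on the values in nums being distinct):
from typing import List

class Solution:
    def permute(self, nums: List[int]) -> List[List[int]]:
        self.res = []
        self.answer = []
        self.backtrack(nums)
        return self.res

    def backtrack(self, nums):
        if len(nums) == len(self.answer):
            self.res.append(self.answer[:])
            return
        for i in range(len(nums)):
            # Linear scan of the path; fine for short inputs
            if nums[i] in self.answer:
                continue
            self.answer.append(nums[i])
            self.backtrack(nums)
            self.answer.pop()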
The same idea with a separate set of used values, which makes the membership check O(1) (it still assumes distinct values):
from typing import List

class Solution:
    def permute(self, nums: List[int]) -> List[List[int]]:
        self.res = []
        self.answer = []
        self.useg = set()  # values already on the current path
        self.backtrack(nums)
        return self.res

    def backtrack(self, nums):
        if len(nums) == len(self.answer):
            self.res.append(self.answer[:])
            return
        for i in range(len(nums)):
            if nums[i] in self.useg:
                continue
            self.useg.add(nums[i])
            self.answer.append(nums[i])
            self.backtrack(nums)
            self.answer.pop()
            self.useg.remove(nums[i])
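Another common approach, not used in the notes above, builds permutations in place by swapping: position `first` is filled with each element from `first` onward in turn, and the swap is undone afterwards. No path list or used marker is needed. A sketch; `permute_swap` is a hypothetical name.

```python
from typing import List


def permute_swap(nums: List[int]) -> List[List[int]]:
    nums = nums[:]  # work on a copy so the caller's list is not modified
    res: List[List[int]] = []

    def backtrack(first: int) -> None:
        if first == len(nums):
            res.append(nums[:])
            return
        for i in range(first, len(nums)):
            nums[first], nums[i] = nums[i], nums[first]  # choose nums[i] for slot first
            backtrack(first + 1)
            nums[first], nums[i] = nums[i], nums[first]  # undo the swap

    backtrack(0)
    return res
```

Note that the output order differs from the other versions: for `[1, 2, 3]` it ends with `[3, 2, 1], [3, 1, 2]`.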
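Permutations II
from typing import List

class Solution:
    def permuteUnique(self, nums: List[int]) -> List[List[int]]:
        self.res = []
        self.answer = []
        nums.sort()  # put equal values next to each other
        self.used = [0] * len(nums)
        self.backtrack(nums)
        return self.res

    def backtrack(self, nums):
        if len(nums) == len(self.answer):
            self.res.append(self.answer[:])
            return
        for i in range(len(nums)):
            if not self.used[i]:
                # Equal values are always taken in index order: a copy may be used
                # only after the copy to its left is already on the path. The order
                # among different values is free, so each distinct permutation is
                # generated exactly once
                if i > 0 and nums[i] == nums[i - 1] and not self.used[i - 1]:
                    continue
                self.used[i] = 1
                self.answer.append(nums[i])
                self.backtrack(nums)
                self.answer.pop()
                self.used[i] = 0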
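A different way to avoid duplicate permutations, shown here as an alternative rather than the method above, is to branch on distinct values instead of positions: keep a count of each value and, at every level, try each value that still has copies left. `permute_unique_counter` is a hypothetical name.

```python
from collections import Counter
from typing import List


def permute_unique_counter(nums: List[int]) -> List[List[int]]:
    counts = Counter(nums)
    res: List[List[int]] = []
    path: List[int] = []

    def backtrack() -> None:
        if len(path) == len(nums):
            res.append(path[:])
            return
        # Each distinct value appears once per level, so no duplicates arise
        for value in sorted(counts):
            if counts[value] == 0:
                continue
            counts[value] -= 1
            path.append(value)
            backtrack()
            path.pop()
            counts[value] += 1

    backtrack()
    return res
```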
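Reconstruct Itinerary
from collections import defaultdict
from typing import List

class Solution:
    def findItinerary(self, tickets: List[List[str]]) -> List[str]:
        self.path = ['JFK']
        self.len_tickets = len(tickets)
        self.ticket_dict = defaultdict(list)
        for start, end in tickets:
            self.ticket_dict[start].append(end)
        # Sort each destination list once, so destinations are tried in lexical order
        for destinations in self.ticket_dict.values():
            destinations.sort()
        self.travel('JFK')
        return self.path

    def travel(self, start):
        # Every ticket used: the path holds len(tickets) + 1 airports
        if len(self.path) == self.len_tickets + 1:
            return True
        destinations = self.ticket_dict[start]
        for i in range(len(destinations)):
            end = destinations.pop(i)      # use the ticket start -> end
            self.path.append(end)
            # Stop at the first complete itinerary: since destinations are tried
            # in lexical order, the first one found is the lexically smallest
            if self.travel(end):
                return True
            self.path.pop()
            destinations.insert(i, end)    # undo: put the ticket back in place
        return False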
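The backtracking search can be slow when many tickets share airports. Since the problem asks for an Eulerian path (every ticket used exactly once), Hierholzer's algorithm is a well-known alternative that needs no backtracking. The sketch below is that swapped-in technique, not the solution above; it assumes a valid itinerary exists, and `find_itinerary` is a hypothetical name.

```python
from collections import defaultdict
from typing import Dict, List


def find_itinerary(tickets: List[List[str]]) -> List[str]:
    graph: Dict[str, List[str]] = defaultdict(list)
    # Insert destinations in reverse lexical order, so pop() from the end
    # always returns the smallest remaining destination
    for src, dst in sorted(tickets, reverse=True):
        graph[src].append(dst)

    route: List[str] = []
    stack = ["JFK"]
    while stack:
        # Follow unused tickets as far as possible
        while graph[stack[-1]]:
            stack.append(graph[stack[-1]].pop())
        # Dead end: this airport is finished, add it to the route
        route.append(stack.pop())
    return route[::-1]  # airports were appended in reverse order
```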
N-Queens
from typing import List

class Solution:
    def solveNQueens(self, n: int) -> List[List[str]]:
        self.n = n
        self.result = []
        # Do not write [['.'] * n] * n: that creates n references to the same row,
        # so placing a queen in one row would change every row
        self.chess_table = [['.'] * n for _ in range(n)]
        self.backtrack(0)
        return self.result

    def backtrack(self, row):
        # One queen per row: row is the depth, col is the choice at this level
        if row == self.n:
            self.result.append([''.join(item) for item in self.chess_table])
            return
        for col in range(self.n):
            if not self.is_valid(row, col):
                continue
            self.chess_table[row][col] = 'Q'
            self.backtrack(row + 1)
            self.chess_table[row][col] = '.'

    def is_valid(self, row, col):
        # Only rows above `row` hold queens so far, so only look upward
        # Same column
        for i in range(row):
            if self.chess_table[i][col] == 'Q':
                return False
        # Upper-left diagonal
        i, j = row - 1, col - 1
        while i >= 0 and j >= 0:
            if self.chess_table[i][j] == 'Q':
                return False
            i -= 1
            j -= 1
        # Upper-right diagonal
        i, j = row - 1, col + 1
        while i >= 0 and j < self.n:
            if self.chess_table[i][j] == 'Q':
                return False
            i -= 1
            j += 1
        return True
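`is_valid` scans the board for every candidate square. A common refinement records the occupied columns and diagonals in sets: squares on the same "\" diagonal share `row - col`, and squares on the same "/" diagonal share `row + col`, so each check is O(1). The sketch below only counts solutions; `total_n_queens` is a hypothetical name.

```python
def total_n_queens(n: int) -> int:
    cols: set = set()   # occupied columns
    diag: set = set()   # occupied "\" diagonals, keyed by row - col
    anti: set = set()   # occupied "/" diagonals, keyed by row + col
    count = 0

    def backtrack(row: int) -> None:
        nonlocal count
        if row == n:
            count += 1
            return
        for col in range(n):
            if col in cols or row - col in diag or row + col in anti:
                continue
            cols.add(col)
            diag.add(row - col)
            anti.add(row + col)
            backtrack(row + 1)
            cols.remove(col)
            diag.remove(row - col)
            anti.remove(row + col)

    backtrack(0)
    return count
```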
Sudoku Solver
from typing import List

class Solution:
    def solveSudoku(self, board: List[List[str]]) -> None:
        """
        Do not return anything, modify board in-place instead.
        """
        self.backtrack(board)

    def backtrack(self, board):
        # Returns True as soon as the board is completely filled
        for i in range(9):
            for j in range(9):
                if board[i][j] != '.':
                    continue
                for num in range(1, 10):
                    if self.is_valid(i, j, num, board):
                        board[i][j] = str(num)
                        if self.backtrack(board):
                            return True
                        board[i][j] = '.'
                # None of 1-9 fits this empty cell: this branch has no solution
                return False
        # No empty cell left: the board is solved
        return True

    def is_valid(self, row, col, num, board):
        ch = str(num)
        # Row
        for i in range(9):
            if board[row][i] == ch:
                return False
        # Column
        for i in range(9):
            if board[i][col] == ch:
                return False
        # 3x3 box containing (row, col)
        space_row = row // 3
        space_col = col // 3
        for i in range(space_row * 3, (space_row + 1) * 3):
            for j in range(space_col * 3, (space_col + 1) * 3):
                if board[i][j] == ch:
                    return False
        return True
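The solver above rescans the whole board for the next empty cell and rechecks row, column and box on every attempt. One possible refinement, sketched below, collects the empty cells once and keeps the digits already present in each row, column and box in sets, so every check is a set lookup. `solve_sudoku` is a hypothetical standalone function that fills `board` in place and returns whether a solution was found.

```python
from typing import List, Set, Tuple


def solve_sudoku(board: List[List[str]]) -> bool:
    rows: List[Set[str]] = [set() for _ in range(9)]
    cols: List[Set[str]] = [set() for _ in range(9)]
    boxes: List[Set[str]] = [set() for _ in range(9)]  # box index = (r // 3) * 3 + c // 3
    empties: List[Tuple[int, int]] = []
    for r in range(9):
        for c in range(9):
            ch = board[r][c]
            if ch == '.':
                empties.append((r, c))
            else:
                rows[r].add(ch)
                cols[c].add(ch)
                boxes[(r // 3) * 3 + c // 3].add(ch)

    def backtrack(k: int) -> bool:
        if k == len(empties):      # every empty cell has been filled
            return True
        r, c = empties[k]
        b = (r // 3) * 3 + c // 3
        for ch in "123456789":
            if ch in rows[r] or ch in cols[c] or ch in boxes[b]:
                continue
            board[r][c] = ch
            rows[r].add(ch)
            cols[c].add(ch)
            boxes[b].add(ch)
            if backtrack(k + 1):
                return True
            board[r][c] = '.'      # undo the placement
            rows[r].discard(ch)
            cols[c].discard(ch)
            boxes[b].discard(ch)
        return False

    return backtrack(0)
```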