回溯算法
笼统地讲,回溯算法很多时候都应用在“搜索”这类问题上。
回溯的处理思想,有点类似枚举搜索。我们枚举所有的解,找到满足期望的解。为了有规律地枚举所有可能的解,避免遗漏和重复,我们把问题求解的过程分为多个阶段。每个阶段,我们都会面对一个岔路口,我们先随意选一条路走,当发现这条路走不通的时候(不符合期望的解),就回退到上一个岔路口,另选一种走法继续走。
回溯算法的思想非常简单,大部分情况下,都是用来解决广义的搜索问题,也就是,从一组可能的解中,选择出一个满足要求的解。回溯算法非常适合用递归来实现,在实现的过程中,剪枝操作是提高回溯效率的一种技巧。利用剪枝,我们并不需要穷举搜索所有的情况,从而提高搜索效率。
回溯算法可以解决的问题:深度优先搜索、八皇后、0-1 背包问题、图的着色、旅行商问题、数独、全排列、正则表达式匹配等等
八皇后问题
我们有一个 8x8 的棋盘,希望往里放 8 个棋子(皇后),每个棋子所在的行、列、对角线都不能有另一个棋子。
我们把这个问题划分成 8 个阶段,依次将 8 个棋子放到第一行、第二行、第三行……第八行。在放置的过程中,我们不停地检查当前放法,是否满足要求。如果满足,则跳到下一行继续放置棋子;如果不满足,那就再换一种放法,继续尝试。
回溯算法非常适合用递归代码实现
#!/usr/bin/python
# -*- coding: UTF-8 -*-
# 棋盘尺寸
BOARD_SIZE = 8
solution_count = 0
queen_list = [0] * BOARD_SIZE
def eight_queens(cur_column: int):
"""
输出所有符合要求的八皇后序列
用一个长度为8的数组代表棋盘的列,数组的数字则为当前列上皇后所在的行数
:return:
"""
if cur_column >= BOARD_SIZE:
global solution_count
solution_count += 1
# 解
print(queen_list)
else:
for i in range(BOARD_SIZE):
if is_valid_pos(cur_column, i):
queen_list[cur_column] = i
eight_queens(cur_column + 1)
def is_valid_pos(cur_column: int, pos: int) -> bool:
"""
因为采取的是每列放置1个皇后的做法
所以检查的时候不必检查列的合法性,只需要检查行和对角
1. 行:检查数组在下标为cur_column之前的元素是否已存在pos
2. 对角:检查数组在下标为cur_column之前的元素,其行的间距pos - QUEEN_LIST[i]
和列的间距cur_column - i是否一致
:param cur_column:
:param pos:
:return:
"""
i = 0
while i < cur_column:
# 同行
if queen_list[i] == pos:
return False
# 对角线
if cur_column - i == abs(pos - queen_list[i]):
return False
i += 1
return True
if __name__ == '__main__':
print('--- eight queens sequence ---')
eight_queens(0)
print('\n--- solution count ---')
print(solution_count)
0-1 背包
我们有一个背包,背包总的承载重量是 Wkg。现在我们有 n 个物品,每个物品的重量不等,并且不可分割。我们现在期望选择几件物品,装载到背包中。在不超过背包所能装载重量的前提下,如何让背包中物品的总重量最大?
对于每个物品来说,都有两种选择,装进背包或者不装进背包。对于 n 个物品来说,总的装法就有 2^n 种,去掉总重量超过 Wkg 的,从剩下的装法中选择总重量最接近 Wkg 的。
如何才能不重复地穷举出这 2^n 种装法呢?
这里就可以用回溯的方法。我们可以把物品依次排列,整个问题就分解为了 n 个阶段,每个阶段对应一个物品怎么选择。先对第一个物品进行处理,选择装进去或者不装进去,然后再递归地处理剩下的物品。
这里还稍微用到了一点搜索剪枝的技巧,就是当发现已经选择的物品的重量超过 Wkg 之后,我们就停止继续探测剩下的物品。
#!/usr/bin/python
# -*- coding: UTF-8 -*-
from typing import List
# 背包选取的物品列表
picks = []
picks_with_max_value = []
def bag(capacity: int, cur_weight: int, items_info: List, pick_idx: int):
"""
回溯法解01背包,穷举
:param capacity: 背包容量
:param cur_weight: 背包当前重量
:param items_info: 物品的重量和价值信息
:param pick_idx: 当前物品的索引
:return:
"""
# 考察完所有物品,或者在中途已经装满
if pick_idx >= len(items_info) or cur_weight == capacity:
global picks_with_max_value
if get_value(items_info, picks) > \
get_value(items_info, picks_with_max_value):
picks_with_max_value = picks.copy()
else:
item_weight = items_info[pick_idx][0]
if cur_weight + item_weight <= capacity: # 选
picks[pick_idx] = 1
bag(capacity, cur_weight + item_weight, items_info, pick_idx + 1)
picks[pick_idx] = 0 # 不选
bag(capacity, cur_weight, items_info, pick_idx + 1)
def get_value(items_info: List, pick_items: List):
values = [_[1] for _ in items_info]
return sum([a*b for a, b in zip(values, pick_items)])
if __name__ == '__main__':
# [(weight, value), ...]
items_info = [(3, 5), (2, 2), (1, 4), (1, 2), (4, 10)]
capacity = 8
print('--- items info ---')
print(items_info)
print('\n--- capacity ---')
print(capacity)
picks = [0] * len(items_info)
bag(capacity, 0, items_info, 0)
print('\n--- picks ---')
print(picks_with_max_value)
print('\n--- value ---')
print(get_value(items_info, picks_with_max_value))
正则表达式
#!/usr/bin/python
# -*- coding: UTF-8 -*-
is_match = False
def rmatch(r_idx: int, m_idx: int, regex: str, main: str):
global is_match
if is_match:
return
if r_idx >= len(regex): # 正则串全部匹配好了
is_match = True
return
if m_idx >= len(main) and r_idx < len(regex): # 正则串没匹配完,但是主串已经没得匹配了
is_match = False
return
if regex[r_idx] == '*': # * 匹配1个或多个任意字符,递归搜索每一种情况
for i in range(m_idx, len(main)):
rmatch(r_idx+1, i+1, regex, main)
elif regex[r_idx] == '?': # ? 匹配0个或1个任意字符,两种情况
rmatch(r_idx+1, m_idx+1, regex, main)
rmatch(r_idx+1, m_idx, regex, main)
else: # 非特殊字符需要精确匹配
if regex[r_idx] == main[m_idx]:
rmatch(r_idx+1, m_idx+1, regex, main)
if __name__ == '__main__':
regex = 'ab*eee?d'
main = 'abcdsadfkjlekjoiwjiojieeecd'
rmatch(0, 0, regex, main)
print(is_match)
从nums选取n个数的全排列
#!/usr/bin/python
# -*- coding: UTF-8 -*-
from typing import List
permutations_list = [] # 全局变量,用于记录每个输出
def permutations(nums: List, n: int, pick_count: int):
"""
从nums选取n个数的全排列
回溯法,用一个栈记录当前路径信息
当满足n==0时,说明栈中的数已足够,输出并终止遍历
:param nums:
:param n:
:param pick_count:
:return:
"""
if n == 0:
print(permutations_list)
else:
for i in range(len(nums) - pick_count):
permutations_list[pick_count] = nums[i]
nums[i], nums[len(nums) - pick_count - 1] = nums[len(nums) - pick_count - 1], nums[i]
permutations(nums, n-1, pick_count+1)
nums[i], nums[len(nums) - pick_count - 1] = nums[len(nums) - pick_count - 1], nums[i]
if __name__ == '__main__':
nums = [1, 2, 3, 4]
n = 3
print('--- list ---')
print(nums)
print('\n--- pick num ---')
print(n)
print('\n--- permutation list ---')
permutations_list = [0] * n
permutations(nums, n, 0)