简介
这是一种在图形平面上,有多个节点的路径,求出最低通过成本的算法。常用于游戏中的NPC的移动计算,或线上游戏的BOT的移动计算上。
搜索过程
我们以上面的图片为例,当前要解决的问题为:得到绿点到红点的最短距离。
- 首先我们应该把这个图片转换为计算机可以识别处理的形式 -> 矩阵
- 现在我们需要定义每一步执行的操作。对于绿色点S,它可以上下左右的移动,但是不能超过边界,也不能碰到蓝色障碍物。
- 开始搜索我们就需要执行以下操作:
(a)创建opened_list
和closed_list
分别存放需要搜索节点和已经处理过的节点
(b)将起点加入opened_list
。
(c)取出opened_list
中 f − l o s s f-loss f−loss最小的节点,通过第二部操作找到周围所有状态也就是S附近的8个节点。这里就需要引入A*算法中F-loss的概念。
A*算法通过启发函数
f
(
n
)
=
g
(
n
)
+
w
(
n
)
f(n) = g(n) + w(n)
f(n)=g(n)+w(n)来指引正确的扩展方向
这里的
f
(
n
)
f(n)
f(n)就是上文的
f
−
l
o
s
s
f-loss
f−loss值。
g
(
n
)
g(n)
g(n)指的是从初始节点到节点 n的实际代价,在这个问题中也就是绿色点S移动到当前位置的步数。
w
(
n
)
w(n)
w(n)节点n到目标节点的最佳路径的估计代价,这个问题中可以把他设定为当前节点到目标红色点的曼哈顿距离
∣
y
1
−
y
2
∣
+
∣
x
1
−
x
2
∣
|y1 - y2| + |x1 - x2|
∣y1−y2∣+∣x1−x2∣
(d)计算周围八个节点的f-loss
值。
(e)把这个节点移入closed_list
(f)找到
f
−
l
o
s
s
f-loss
f−loss最小的节点,对他按照如下规则进行检查。
如果它在
closed_list
中,或者是不可走,忽略它;
如果它不在opened_list
中,将其加入opened_list
,并且将当前方格设置为它的父亲节点,记录这个方格的G、W和F值
如果它已经在opened_list
中,判断这个节点和opened_list
存在的节点哪个更好。上文提到过启发函数中 g ( n ) g(n) g(n)指的是从初始节点到节点 n的实际代价,那么 g ( n ) g(n) g(n)越小的节点就更优。将它的父亲节点重新设置为当前方格节点。
(g)重复步骤(c)- (f),就可以最优路径。
八数码问题
这里我引入八数码这个很经典的问题来详细说明A*算法
介绍
在3×3的棋盘上,摆有八个棋子,每个棋子上标有1至8的某一数字。棋盘中留有一个空格,空格用0来表示。空格周围的棋子可以移到空格中。要求解的问题是:给出一种初始布局和目标布局找到一种最少步骤的移动方法,实现从初始布局到目标布局的转变。
状态空间
对于八数码问题,我们可以很直观的看出。可以把 3 × 3 3 × 3 3×3的格子转换为二维数组。
def __init__(self):
self.start_state = [[2, 8, 3], [1, 6, 4], [7, 0, 5]]
self.end_state = [[1, 2, 3], [8, 0, 4], [7, 6, 5]]
self.opened = []
self.closed = []
操作
二维数组的所有操作就是将空方格向上下左右移动。
class Direction(int, Enum):
UP = 0
RIGHT = 1
DOWN = 2
LEFT = 3
@staticmethod
def move(state: List[List[int]], direction: int) -> List[List[int]]:
"""move zero"""
def find_zero() -> Tuple[int, int]:
"""find col and row where zero locates"""
for idx, row in enumerate(state):
if 0 in row:
return idx, row.index(0)
result = copy.deepcopy(state)
y, x = find_zero()
if direction == Direction.UP:
if y == 0:
# raise Exception("can't move up")
return result
result[y][x], result[y - 1][x] = result[y - 1][x], result[y][x]
elif direction == Direction.RIGHT:
if x == 2:
# raise Exception("can't move right")
return result
result[y][x], result[y][x + 1] = result[y][x + 1], result[y][x]
elif direction == Direction.DOWN:
if y == 2:
# raise Exception("can't move down")
return result
result[y][x], result[y + 1][x] = result[y + 1][x], result[y][x]
elif direction == Direction.LEFT:
if x == 0:
# raise Exception("can't move left")
return result
result[y][x], result[y][x - 1] = result[y][x - 1], result[y][x]
return result
节点
现在我们需要为每一步操作都定义一个节点来记录。首先我们必须记录一下信息:当前状态空间、 g ( n ) g(n) g(n)、父节点、 f − l o s s f-loss f−loss。记录 g ( n ) g(n) g(n)是为了比较两个相同状态空间的节点哪一个更优(上文有提及),记录父节点是为了可以得到完整路径,记录 f − l o s s f-loss f−loss是为了排序选择节点。
def create_node(self, state, step, parent) -> list:
"""create node"""
return state, step, parent, self.cal_w(state) + step
def cal_w(self, state: List[List[int]]) -> int:
"""calculate w(n): number of digits with wrong position"""
number = 0
for y in range(3):
for x in range(3):
if state[y][x] != 0 and state[y][x] != self.end_state[y][x]:
number += 1
return number
这个问题中我们把 w ( n ) w(n) w(n)设定为当前状态矩阵和目标状态矩阵有几个节点不同。
主函数
主函数中执行的就是全部步骤(a)->(g)。
def main(self):
"""main function"""
start_node = self.create_node(self.start_state, 0, None)
self.opened.append(start_node)
con = 0
while self.opened:
node = self.opened[0]
self.opened = self.opened[1:]
if node[0] == self.end_state:
return con, node
self.closed.append(node[0])
for direction in Direction:
temp = eight_figures.move(node[0], direction)
child_node = self.create_node(temp, node[1], node)
if child_node[0] not in self.closed:
self.update_open(child_node)
self.opened.sort(key=lambda x: (x[3], -x[1]))
con += 1
这里需要提及代码中的lambda表达式
lambda x: (x[3], -x[1])
,排序opened_list
我们先对 f − l o s s f-loss f−loss进行升序,再对 w ( n ) w(n) w(n)进行降序
代码中的update_open
函数就是步骤(f)中对节点进行检查,详细可以回顾上文。
def update_open(self, cnt_node: list) -> bool:
"""update opened queue: judge the node based on its f-loss"""
tmp_opened = self.opened[:]
for idx, node in enumerate(tmp_opened):
if node[0] == cnt_node[0]:
if node[3] <= cnt_node[3]:
return False
else:
tmp_opened[idx] = cnt_node
self.opened = tmp_opened
return True
tmp_opened.append(cnt_node)
self.opened = tmp_opened[:]
return True
结果
最终得到终点时,会返回当前节点。从当前节点不断回溯他的父节点就能得到一条完整的路径。
if __name__ == "__main__":
cls = eight_figures()
con, node = cls.main()
print(f"花费{con+1}轮")
for i in range(con + 1):
print(node[0])
node = node[2]
完整代码
#! /usr/bin/env python
# -*- coding: utf-8 -*-#-
import copy
from typing import Tuple, List
from queue import Queue
from enum import Enum
class Direction(int, Enum):
UP = 0
RIGHT = 1
DOWN = 2
LEFT = 3
class eight_figures:
def __init__(self):
self.start_state = [[2, 8, 3], [1, 6, 4], [7, 0, 5]]
self.end_state = [[1, 2, 3], [8, 0, 4], [7, 6, 5]]
self.opened = []
self.closed = []
@staticmethod
def move(state: List[List[int]], direction: int) -> List[List[int]]:
"""move zero"""
def find_zero() -> Tuple[int, int]:
"""find col and row where zero locates"""
for idx, row in enumerate(state):
if 0 in row:
return idx, row.index(0)
result = copy.deepcopy(state)
y, x = find_zero()
if direction == Direction.UP:
if y == 0:
# raise Exception("can't move up")
return result
result[y][x], result[y - 1][x] = result[y - 1][x], result[y][x]
elif direction == Direction.RIGHT:
if x == 2:
# raise Exception("can't move right")
return result
result[y][x], result[y][x + 1] = result[y][x + 1], result[y][x]
elif direction == Direction.DOWN:
if y == 2:
# raise Exception("can't move down")
return result
result[y][x], result[y + 1][x] = result[y + 1][x], result[y][x]
elif direction == Direction.LEFT:
if x == 0:
# raise Exception("can't move left")
return result
result[y][x], result[y][x - 1] = result[y][x - 1], result[y][x]
return result
def cal_w(self, state: List[List[int]]) -> int:
"""calculate w(n): number of digits with wrong position"""
number = 0
for y in range(3):
for x in range(3):
if state[y][x] != 0 and state[y][x] != self.end_state[y][x]:
number += 1
return number
def create_node(self, state, step, parent) -> list:
"""create node"""
return state, step, parent, self.cal_w(state) + step
@staticmethod
def state2int(state: List[List[int]]) -> Tuple[int, list]:
"""transform the state into a number: like hash"""
result = 0
for row in state:
for val in row:
result = result * 10 + val
return result
def update_open(self, cnt_node: list) -> bool:
"""update opened queue: judge the node based on its f-loss"""
tmp_opened = self.opened[:]
for idx, node in enumerate(tmp_opened):
if node[0] == cnt_node[0]:
if node[3] <= cnt_node[3]:
return False
else:
tmp_opened[idx] = cnt_node
self.opened = tmp_opened
return True
tmp_opened.append(cnt_node)
self.opened = tmp_opened[:]
return True
def main(self):
"""main function"""
start_node = self.create_node(self.start_state, 0, None)
self.opened.append(start_node)
con = 0
while self.opened:
node = self.opened[0]
self.opened = self.opened[1:]
if node[0] == self.end_state:
return con, node
self.closed.append(node[0])
for direction in Direction:
temp = eight_figures.move(node[0], direction)
child_node = self.create_node(temp, node[1], node)
if child_node[0] not in self.closed:
self.update_open(child_node)
self.opened.sort(key=lambda x: (x[3], -x[1]))
con += 1
if __name__ == "__main__":
cls = eight_figures()
con, node = cls.main()
print(f"花费{con+1}轮")
for i in range(con + 1):
print(node[0])
node = node[2]