A*算法
本文仅供学习记录,侵删
参考资料:
本文章在上面参考文章的基础上,做了简单的修改。
- 参考文章中节点到起点的距离用的是对角距离;本文改为继承父节点的距离+节点到父节点的距离
- 参考文章中没考虑已在open_set中的节点信息更新。如果邻近节点在open_set中,应当比较gcost是否比原来小,如果更小则更新其父节点
主函数
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
import random_map
import a_star
plt.figure(figsize=(5, 5))
map = random_map.RandomMap()
ax = plt.gca()
ax.set_xlim([0, map.size])
ax.set_ylim([0, map.size])
for i in range(map.size):
for j in range(map.size):
if map.IsObstacle(i, j):
rec = Rectangle((i, j), width=1, height=1, color='gray')
ax.add_patch(rec)
else:
rec = Rectangle((i, j), width=1, height=1, edgecolor='gray', facecolor='w')
ax.add_patch(rec)
rec = Rectangle((0, 0), width=1, height=1, facecolor='b')
ax.add_patch(rec)
rec = Rectangle((map.size-1, map.size-1), width=1, height=1, facecolor='r')
ax.add_patch(rec)
plt.axis('equal')
plt.axis('off')
plt.tight_layout()
# plt.show()
a_star = a_star.AStar(map)
a_star.RunAndSaveImage(ax, plt)
生成地图并随机生成障碍物
"""
随机生成地图,同时在地图中生成一些障碍物
1. 构造函数,地图默认大小是50*50
2. 设置障碍物的数量未地图大小除以8
3. 调用GenerateObstacle生成随机障碍物
4. 在地图的中间生成一个歇着的障碍物
5. 随机生成几个其他的障碍物
6. 障碍物的方向也是随机的
7. 定义一个方法来判断某个节点是否是障碍物
"""
import numpy as np
import point
class RandomMap:
def __init__(self, size=50):
self.obstacle_point = []
self.size = size
self.obstacle = size // 8
self.GenerateObstacle()
# 产生障碍物
def GenerateObstacle(self):
self.obstacle_point.append(point.Point(self.size // 2, self.size // 2))
self.obstacle_point.append(point.Point(self.size // 2, self.size // 2 - 1))
# 在地图中间添加障碍物
for i in range(self.size // 2 - 4, self.size // 2):
self.obstacle_point.append(point.Point(i, self.size - i))
self.obstacle_point.append(point.Point(i, self.size - i - 1))
self.obstacle_point.append(point.Point(self.size - i, i))
self.obstacle_point.append(point.Point(self.size - i, i - 1))
# 随机添加障碍物
for i in range(self.obstacle - 1):
x = np.random.randint(0, self.size)
y = np.random.randint(0, self.size)
self.obstacle_point.append(point.Point(x, y))
if np.random.rand() > 0.5:
for j in range(self.size // 4):
self.obstacle_point.append(point.Point(x, y + j))
pass
else:
for j in range(self.size // 4):
self.obstacle_point.append(point.Point(x + j, y))
pass
# 判断是否是障碍物
def IsObstacle(self, i, j):
for p in self.obstacle_point:
if i == p.x and j == p.y:
return True
return False
节点信息
"""
可以用来产生障碍物
"""
import sys
class Point:
def __init__(self, x, y):
self.x = x
self.y = y
self.cost = sys.maxsize
self.gCost = sys.maxsize
self.parent = None
A*算法主要部分
import sys
import time
import numpy as np
from matplotlib.patches import Rectangle
import point
import random_map
class AStar:
def __init__(self, map):
self.map = map
self.open_set = []
self.close_set = []
# 到起点的成本,g(n)
def BaseCost(self, p):
# 邻近节点在上下左右是为1,否则为1.4
if abs(p.x - p.parent.x) + abs(p.y - p.parent.y) > 1.5:
baseC = 1.4
else:
baseC = 1
return baseC + p.parent.gCost
# 启发函数,到终点的成本,h(n)
def HeuristicCost(self, p):
x_dis = self.map.size - 1 - p.x
y_dis = self.map.size - 1 - p.y
return x_dis + y_dis + (np.sqrt(2) - 2) * min(x_dis, y_dis)
# 总成本 = g(n) + h(n)
def TotalCost(self, p):
return self.BaseCost(p) + self.HeuristicCost(p)
# 判断点是否是有效点
def IsValidPoint(self, x, y):
if x < 0 or y < 0:
return False
if x >= self.map.size or y >= self.map.size:
return False
return not self.map.IsObstacle(x, y)
# 判断点是否在某个集合中
def IsInPointList(self, p, point_list):
for poi in point_list:
if poi.x == p.x and poi.y == p.y:
return True
return False
# 判断点是否在open_list中
def IsInOpenList(self, p):
return self.IsInPointList(p, self.open_set)
# 判断点是否在close_list中
def IsInCloseList(self, p):
return self.IsInPointList(p, self.close_set)
# 从open_list中找到顶点的位置
def FindIndex(self, x, y, point_list):
for index in range(len(point_list)):
if x == point_list[index].x and y == point_list[index].y:
return index
return False
# 是否是开始节点
def IsStartPoint(self, p):
return p.x == 0 and p.y == 0
# 是否是结束节点
def IsEndPoint(self, p):
return p.x == self.map.size - 1 and p.y == self.map.size - 1
# 运行及记录运行轨迹
def RunAndSaveImage(self, ax, plt):
start_time = time.time()
start_point = point.Point(0, 0)
start_point.cost = 0
start_point.gCost = 0
self.open_set.append(start_point)
while True:
index = self.SelectPointInOpenList()
if index < 0:
print(' No path Found, algorithm failed!!!')
return
p = self.open_set[index]
rec = Rectangle((p.x, p.y), 1, 1, color='c')
ax.add_patch(rec)
# self.SaveImage(plt)
if self.IsEndPoint(p):
return self.BuildPath(p, ax, plt, start_time)
del self.open_set[index]
self.close_set.append(p)
# 邻接节点
x = p.x
y = p.y
self.ProcessPoint(x - 1, y + 1, p)
self.ProcessPoint(x - 1, y, p)
self.ProcessPoint(x - 1, y - 1, p)
self.ProcessPoint(x, y - 1, p)
self.ProcessPoint(x + 1, y + 1, p)
self.ProcessPoint(x + 1, y, p)
self.ProcessPoint(x + 1, y - 1, p)
self.ProcessPoint(x, y + 1, p)
def SaveImage(self, plt):
millis = int(round(time.time() * 1000))
filename = './' + str(millis) + '.png'
plt.savefig(filename)
# 针对每一个节点进行处理:如果是没有处理过的节点,则计算优先级设置父节点,并且添加到open_set中。
def ProcessPoint(self, x, y, parent):
if not self.IsValidPoint(x, y):
return # 无效点不作处理
p = point.Point(x, y)
if self.IsInCloseList(p):
return # close_list中的点不作处理
p.parent = parent
p.gCost = self.BaseCost(p)
p.cost = self.TotalCost(p)
if not self.IsInOpenList(p): # 可能少了点东西
self.open_set.append(p)
print('Process Point [', p.x, ',', p.y, ']', ', cost:', p.cost)
else:
# 写一个找到p在open_list中索引的函数
index = self.FindIndex(x, y, self.open_set)
# 如果邻近节点在open_set,则比较gcost是否比原来更小,如果更小则更新其父节点
if p.gCost < self.open_set[index].gCost:
self.open_set[index] = p
print('Process Point [', p.x, ',', p.y, ']', ', cost:', p.cost)
# 从open_set中找到优先级最高的节点,返回其索引。
def SelectPointInOpenList(self):
index = 0
select_index = -1
min_cost = sys.maxsize
for p in self.open_set:
cost = p.cost
if cost < min_cost:
min_cost = cost
select_index = index
index += 1
return select_index
# 终点往回沿着parent构造结果路径。然后从起点开始绘制结果,结果使用绿色方块,每次绘制一步便保存一个图片。
def BuildPath(self, p, ax, plt, start_time):
path = []
result = []
while True:
path.insert(0, p) # Insert first
if self.IsStartPoint(p):
break
else:
p = p.parent
for p in path:
rec = Rectangle((p.x, p.y), 1, 1, color='g')
ax.add_patch(rec)
plt.draw()
# self.SaveImage(plt)
result.append([p.x, p.y])
self.SaveImage(plt)
print(result)
end_time = time.time()
print('==== Algorithm finish in ', int(end_time - start_time), 'seconds')