本文使用 Python 语言实现 A* 算法。
算法流程和原理不赘述。
代码文件结构:
point.py
import sys
class Point(object):
    """
    A grid cell used as a node of the path search.

    Attributes:
        x, y: integer grid coordinates of the cell.
        cost: cost attached to the cell by the search; ``sys.maxsize``
            until a real cost has been assigned.
        parent: the Point this cell was reached from, ``None`` when no
            predecessor has been assigned.
    """

    def __init__(self, x: int, y: int):
        """
        :param x: x coordinate
        :param y: y coordinate
        """
        self.x = x
        self.y = y
        # sys.maxsize acts as "not evaluated yet"
        self.cost = sys.maxsize
        self.parent = None

    def __repr__(self):
        # parent is deliberately left out: printing it would recurse through
        # the whole chain of predecessors
        return "Point(x={!r}, y={!r}, cost={!r})".format(self.x, self.y, self.cost)
map.py
from typing import Tuple, List
from point import Point
class Map(object):
    """
    A rectangular grid map with a list of blocked cells.

    Attributes:
        width: number of cells along x.
        height: number of cells along y.
        obstacles: list of Point objects, one per blocked cell.
    """

    def __init__(self, width: int, height: int, obstacles: List[Tuple[int, int]] = None):
        """
        :param width: map width
        :param height: map height
        :param obstacles: (x, y) coordinates of the blocked cells; omitted or
            None means the map has no obstacles
        """
        # None instead of a shared mutable [] default
        if obstacles is None:
            obstacles = []
        self.width = width
        self.height = height
        self.obstacles = [Point(x=osc[0], y=osc[1]) for osc in obstacles]

    def is_obstacle(self, i: int, j: int):
        """
        :param i: x coordinate
        :param j: y coordinate
        :return: True if the cell (i, j) is one of the obstacles, else False
        """
        return any(i == p.x and j == p.y for p in self.obstacles)
a_star.py
包含可视化代码:搜索过程中每一步都会保存一张图片到 outputs 目录,最终将这些图片合成为视频;视频生成后,代码会自动删除中间生成的图片。
import os
import sys
import time
from typing import Tuple, List
from matplotlib.patches import Rectangle
import cv2
import glob
from point import Point
from map import Map
class AStar(object):
    """
    A* algorithm

    Shortest-path search on a 4-connected grid map where every move to a
    neighbouring cell costs one step. The Manhattan distance to the target is
    used as heuristic. Each expanded cell is drawn and saved as an image; once
    the target is reached the images are merged into a video.
    """

    def __init__(self, map: Map, origin: Tuple[int, int], target: Tuple[int, int]):
        """
        initialise
        :param map: map
        :param origin: starting point coordinates
        :param target: ending point coordinates
        """
        self.map = map
        self.origin = Point(x=origin[0], y=origin[1])
        self.target = Point(x=target[0], y=target[1])
        self.open_points = []
        self.close_points = []
        # cheapest known number of steps from the origin, keyed by (x, y)
        self._g_costs = {}
        # timestamp used for the last saved frame; keeps file names unique
        # and strictly increasing
        self._last_frame_ms = 0

    def _basic_cost(self, point: Point):
        """
        basic cost from origin: number of steps of the cheapest path found so
        far; for a cell that has not been reached yet the Manhattan distance
        from the origin (a lower bound) is returned
        """
        key = (point.x, point.y)
        if key in self._g_costs:
            return self._g_costs[key]
        return abs(point.x - self.origin.x) + abs(point.y - self.origin.y)

    def _heuristic_cost(self, point: Point):
        """
        estimated cost to target (Manhattan distance)
        """
        return abs(point.x - self.target.x) + abs(point.y - self.target.y)

    def _total_cost(self, point: Point):
        """
        total cost: basic cost plus heuristic cost
        """
        return self._basic_cost(point) + self._heuristic_cost(point)

    def _is_valid_point(self, x: int, y: int):
        """
        :return: True if (x, y) lies inside the map and is not an obstacle
        """
        if x < 0 or y < 0:
            return False
        if x >= self.map.width or y >= self.map.height:
            return False
        if self.map.is_obstacle(x, y):
            return False
        return True

    def _in_point_list(self, point: Point, points: List[Point]):
        """
        :return: True if a point with the same coordinates is in the list
        """
        return any(point.x == p.x and point.y == p.y for p in points)

    def _in_open_list(self, point: Point):
        return self._in_point_list(point, self.open_points)

    def _in_close_list(self, point: Point):
        return self._in_point_list(point, self.close_points)

    def _find_open_point(self, x: int, y: int):
        """
        :return: the open-list entry at (x, y), or None if there is none
        """
        for p in self.open_points:
            if p.x == x and p.y == y:
                return p
        return None

    def run(self, ax, plt):
        """
        run algorithm and visualise
        :param ax: matplotlib.axes._subplots.AxesSubplot
        :param plt: matplotlib.pyplot
        :return: list of points from origin to target, None if no path exists
        """
        tms = time.time()
        # start from a clean state so that run() can be called more than once
        self.open_points = []
        self.close_points = []
        self._g_costs = {(self.origin.x, self.origin.y): 0}
        self.origin.parent = None
        self.origin.cost = self._total_cost(self.origin)
        self.open_points.append(self.origin)
        while True:
            idx = self._select_from_open_list()
            if idx < 0:
                print("No path found, algorithm failed!")
                return None
            point = self.open_points[idx]
            # mark the cell that is being expanded and store the frame
            rectangle = Rectangle(xy=(point.x, point.y), width=1, height=1, color='cyan')
            ax.add_patch(rectangle)
            self._save_image(plt)
            if point.x == self.target.x and point.y == self.target.y:
                return self._build_path(point=point, tms=tms, ax=ax, plt=plt)
            del self.open_points[idx]
            self.close_points.append(point)
            # neighbours
            self._process_point(x=point.x - 1, y=point.y, parent=point)
            self._process_point(x=point.x, y=point.y - 1, parent=point)
            self._process_point(x=point.x + 1, y=point.y, parent=point)
            self._process_point(x=point.x, y=point.y + 1, parent=point)

    def _save_image(self, plt):
        """
        save images to outputs folder
        :param plt: matplotlib.pyplot
        """
        # savefig does not create missing directories
        os.makedirs('./outputs', exist_ok=True)
        millisecond = int(round(time.time() * 1000))
        # two frames within the same millisecond must not overwrite each other
        if millisecond <= self._last_frame_ms:
            millisecond = self._last_frame_ms + 1
        self._last_frame_ms = millisecond
        file_name = './outputs/' + str(millisecond) + '.png'
        plt.savefig(file_name)

    def _process_point(self, x: int, y: int, parent: Point):
        """
        process current point
        :param x: x coordinate
        :param y: y coordinate
        :param parent: current point's parent point
        """
        # do nothing for invalid point
        if not self._is_valid_point(x, y):
            return
        # do nothing for visited point
        point = Point(x, y)
        if self._in_close_list(point):
            return
        # every move between neighbouring cells costs one step
        new_cost = self._basic_cost(parent) + 1
        known = self._find_open_point(x, y)
        if known is None:
            # first time this cell is reached
            point.parent = parent
            self._g_costs[(x, y)] = new_cost
            point.cost = self._total_cost(point)
            self.open_points.append(point)
        else:
            point = known
            if new_cost < self._basic_cost(known):
                # a cheaper route to an already discovered cell: re-parent it
                known.parent = parent
                self._g_costs[(x, y)] = new_cost
                known.cost = self._total_cost(known)
        print("process point [{}, {}], cost: {}".format(point.x, point.y, point.cost))

    def _select_from_open_list(self) -> int:
        """
        select the point with least cost from the open list
        :return idx_select: the index of the selected point in the open list,
            -1 if the open list is empty
        """
        idx_select = -1
        best_key = None
        for idx, point in enumerate(self.open_points):
            # ties on total cost are broken in favour of the cell closer to
            # the target; remaining ties keep the earliest entry
            key = (self._total_cost(point), self._heuristic_cost(point))
            if best_key is None or key < best_key:
                best_key = key
                idx_select = idx
        return idx_select

    def _build_path(self, point: Point, tms: float, ax, plt):
        """
        build the whole path after algorithm terminates
        :param point: ending point
        :param tms: start time
        :param ax: matplotlib.axes._subplots.AxesSubplot
        :param plt: matplotlib.pyplot
        :return: list of points from origin to ending point
        """
        # get whole path: follow the parents back, then reverse once
        path = []
        node = point
        while node is not None:
            path.append(node)
            if node.x == self.origin.x and node.y == self.origin.y:
                break
            node = node.parent
        path.reverse()
        # visualise
        for p in path:
            rec = Rectangle(xy=(p.x, p.y), width=1, height=1, color='green')
            ax.add_patch(rec)
            plt.draw()
            self._save_image(plt)
        self._merge_video()
        tme = time.time()
        print("Algorithm finishes in {} s".format(int(tme - tms)))
        return path

    def _merge_video(self):
        """
        merge images to video
        """
        # get image files; glob gives no ordering guarantee, the timestamp
        # file names are sorted to restore the order in which they were saved
        image_files = []
        file_names = []
        for file_name in sorted(glob.glob('./outputs/*.png')):
            image = cv2.imread(filename=file_name)
            if image is None:
                # unreadable file: leave it on disk and keep going
                print("Skip unreadable image {}".format(file_name))
                continue
            file_names.append(file_name)
            image_files.append(image)
        if not image_files:
            print("No images found, no video generated!")
            return
        # all frames come from the same figure; the first one defines the size
        height, width = image_files[0].shape[:2]
        size = (width, height)
        # generate video
        tm = time.time()
        video_path = f'./outputs/{round(tm)}.avi'
        fourcc = cv2.VideoWriter_fourcc(*'DIVX')
        video = cv2.VideoWriter(video_path, fourcc, 5, size)
        if not video.isOpened():
            # keep the images when the video cannot be written
            video.release()
            print("Cannot open video file {}, images are kept".format(video_path))
            return
        for image in image_files:
            video.write(image)
        video.release()
        # delete original image files
        for file in file_names:
            os.remove(file)
main.py(主程序)
from matplotlib import pyplot as plt
from matplotlib.patches import Rectangle
from map import Map
from a_star import AStar
""" map settings """
import os

# map settings
width, height = 10, 15
origin, target = (0, 0), (width - 1, height - 1)
# three vertical walls: the outer two start at the bottom, the middle one
# reaches the top, so the path has to weave between them
obstacles = [(round(width * (1 / 4)), j) for j in range(round(height * (2 / 3)))] + [
    (round(width * (1 / 2)), j) for j in range(round(height * (1 / 3)), height)] + [
    (round(width * (3 / 4)), j) for j in range(round(height * (2 / 3)))]
map_ = Map(width=width, height=height, obstacles=obstacles)

# visual settings
plt.figure(figsize=(5, 5))
ax = plt.gca()
ax.set_xlim([0, map_.width])
ax.set_ylim([0, map_.height])
for i in range(map_.width):
    for j in range(map_.height):
        if map_.is_obstacle(i, j):
            rectangle = Rectangle(xy=(i, j), width=1, height=1, color='gray')
        else:
            rectangle = Rectangle(xy=(i, j), width=1, height=1, edgecolor='gray', facecolor='white')
        ax.add_patch(rectangle)
rectangle = Rectangle(xy=origin, width=1, height=1, facecolor='blue')
ax.add_patch(rectangle)
rectangle = Rectangle(xy=target, width=1, height=1, facecolor='red')
ax.add_patch(rectangle)
plt.axis('equal')  # set equal scaling
plt.axis('off')  # turn off axis lines and labels
plt.tight_layout()

# algorithm
# the frames are written to ./outputs, which must exist before saving
os.makedirs('./outputs', exist_ok=True)
# reuse the origin/target defined with the map settings instead of repeating them
a_star = AStar(map=map_, origin=origin, target=target)
a_star.run(ax, plt)