A* 算法的 Python 实现

本文使用 Python 语言实现 A* 算法。
算法流程和原理不赘述。

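代码文件结构(四个文件放在同一目录下,相互之间通过 import 引用):

- point.py:网格点类 Point
- map.py:地图类 Map(尺寸与障碍物)
- a_star.py:A* 算法及可视化
- main.py:主程序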

point.py

import sys


class Point:
    """
    A single grid cell as seen by the path search.
    """

    def __init__(self, x: int, y: int):
        """
        :param x:  column coordinate of the cell
        :param y:  row coordinate of the cell
        """
        self.x, self.y = x, y

        # cost value written by the search; sys.maxsize means "not set yet"
        self.cost = sys.maxsize

        # predecessor of this cell on the path being built; None until assigned
        self.parent = None


map.py

from typing import List, Optional, Tuple

from point import Point


class Map(object):
    """
    Rectangular grid map with a set of blocked (obstacle) cells.
    """

    def __init__(self, width: int, height: int, obstacles: Optional[List[Tuple[int, int]]] = None):
        """
        :param width:  grid width in cells (extent along x)
        :param height:  grid height in cells (extent along y)
        :param obstacles:  (x, y) coordinates of blocked cells; omit or pass None for no obstacles
        """
        self.width = width
        self.height = height

        # None instead of a mutable [] default, so no list object is shared between calls
        if obstacles is None:
            obstacles = []
        self.obstacles = [Point(x=osc[0], y=osc[1]) for osc in obstacles]

        # coordinate set for O(1) lookups in is_obstacle; it is built once here,
        # so obstacles are expected to be supplied at construction time
        self._obstacle_cells = {(p.x, p.y) for p in self.obstacles}

    def is_obstacle(self, i: int, j: int) -> bool:
        """
        check whether a cell is blocked

        :param i:  x coordinate of the cell
        :param j:  y coordinate of the cell
        :return:  True if (i, j) was given as an obstacle, else False
        """
        return (i, j) in self._obstacle_cells


a_star.py
包含可视化代码:搜索过程中的每一步都会保存为图片(位于 ./outputs 目录),搜索结束后合成为视频,并自动删除中间生成的图片。运行前需安装 matplotlib 与 opencv-python。

import os
import sys
import time
from typing import Tuple, List

from matplotlib.patches import Rectangle
import cv2
import glob

from point import Point
from map import Map


class AStar(object):
    """
    A* shortest-path search on a 4-connected grid, with step-by-step visualisation.

    Every move between horizontally or vertically adjacent cells costs 1, and
    the heuristic is the Manhattan distance to the target. That heuristic never
    overestimates the remaining cost on such a grid, so the path found is a
    shortest one and an expanded (closed) point never needs to be revisited.

    Each expansion is drawn on the supplied matplotlib axes and saved as a PNG
    frame under OUTPUT_DIR; when the search ends the frames are merged into an
    .avi video and the PNG files are deleted.
    """

    # folder that receives the PNG frames and the final video
    OUTPUT_DIR = './outputs'

    def __init__(self, map: Map, origin: Tuple[int, int], target: Tuple[int, int]):
        """
        initialise

        :param map:  map to search on
        :param origin:  starting point coordinates (x, y)
        :param target:  ending point coordinates (x, y)
        """

        self.map = map

        self.origin = Point(x=origin[0], y=origin[1])
        self.target = Point(x=target[0], y=target[1])

        # frontier: discovered points whose neighbours have not been expanded yet
        self.open_points = []
        # points already expanded
        self.close_points = []

        # file paths of the PNG frames saved by this instance, in drawing order
        self._frames = []

    def _basic_cost(self, point: Point):
        """
        cost of the best path found so far from the origin to ``point``
        (g in A* terms); it is stored on the point while the search runs
        """
        return point.cost

    def _heuristic_cost(self, point: Point):
        """
        estimated cost to target (Manhattan distance)
        """
        return abs(point.x - self.target.x) + abs(point.y - self.target.y)

    def _total_cost(self, point: Point):
        """
        total cost f = g + h
        """
        return self._basic_cost(point) + self._heuristic_cost(point)

    def _is_valid_point(self, x: int, y: int):
        """
        True if (x, y) lies inside the map and is not an obstacle
        """
        if x < 0 or y < 0:
            return False
        if x >= self.map.width or y >= self.map.height:
            return False
        if self.map.is_obstacle(x, y):
            return False

        return True

    @staticmethod
    def _find_in_list(x: int, y: int, points: List[Point]):
        """
        return the point at coordinates (x, y) from ``points``, or None if absent
        """
        return next((p for p in points if p.x == x and p.y == y), None)

    def _in_point_list(self, point: Point, points: List[Point]):
        return self._find_in_list(point.x, point.y, points) is not None

    def _in_open_list(self, point: Point):
        return self._in_point_list(point, self.open_points)

    def _in_close_list(self, point: Point):
        return self._in_point_list(point, self.close_points)

    def run(self, ax, plt):
        """
        run algorithm and visualise

        :param ax:  matplotlib.axes._subplots.AxesSubplot
        :param plt:  matplotlib.pyplot
        :return:  list of points from origin to target (both included),
                  or None when the target cannot be reached
        """

        tms = time.time()

        # start from a clean state so run() can be called more than once
        self.origin.cost = 0
        self.origin.parent = None
        self.open_points = [self.origin]
        self.close_points = []

        while True:
            idx = self._select_from_open_list()
            if idx < 0:
                print("No path found, algorithm failed!")
                # still turn the exploration frames into a video so the PNGs are not left behind
                self._merge_video()
                return None
            point = self.open_points[idx]

            rectangle = Rectangle(xy=(point.x, point.y), width=1, height=1, color='cyan')
            ax.add_patch(rectangle)
            self._save_image(plt)

            # the goal test happens on expansion (not discovery) so the path is optimal
            if point.x == self.target.x and point.y == self.target.y:
                return self._build_path(point=point, tms=tms, ax=ax, plt=plt)

            del self.open_points[idx]
            self.close_points.append(point)

            # 4-connected neighbours: left, down, right, up
            for dx, dy in ((-1, 0), (0, -1), (1, 0), (0, 1)):
                self._process_point(x=point.x + dx, y=point.y + dy, parent=point)

    def _save_image(self, plt):
        """
        save the current figure as the next numbered PNG frame in OUTPUT_DIR
        """

        os.makedirs(self.OUTPUT_DIR, exist_ok=True)
        # zero-padded sequence numbers keep frames unique and in drawing order
        # (millisecond timestamps can collide, and glob() order is arbitrary)
        file_name = os.path.join(self.OUTPUT_DIR, '{:06d}.png'.format(len(self._frames)))
        plt.savefig(file_name)
        self._frames.append(file_name)

    def _process_point(self, x: int, y: int, parent: Point):
        """
        discover or relax the neighbour (x, y) of ``parent``

        :param x:  x coordinate
        :param y:  y coordinate
        :param parent:  point currently being expanded
        """

        # do nothing for invalid point
        if not self._is_valid_point(x, y):
            return

        # do nothing for already expanded point
        if self._find_in_list(x, y, self.close_points) is not None:
            return

        # every step between 4-connected neighbours costs 1
        new_cost = parent.cost + 1

        point = self._find_in_list(x, y, self.open_points)
        if point is None:
            point = Point(x, y)
            point.parent = parent
            point.cost = new_cost
            self.open_points.append(point)
        elif new_cost < point.cost:
            # a cheaper route to an already discovered point: re-parent it
            point.parent = parent
            point.cost = new_cost

        print("process point [{}, {}],  cost: {}".format(point.x, point.y, self._total_cost(point)))

    def _select_from_open_list(self) -> int:
        """
        select the point with least total cost from the open list

        :return:  index of the selected point in the open list, or -1 if it is empty
        """

        if not self.open_points:
            return -1
        # min() returns the first index among equal costs
        return min(range(len(self.open_points)),
                   key=lambda i: self._total_cost(self.open_points[i]))

    def _build_path(self, point: Point, tms: float, ax, plt):
        """
        build the whole path after algorithm terminates

        :param point:  ending point
        :param tms:  start time
        :param ax:  matplotlib.axes._subplots.AxesSubplot
        :param plt:  matplotlib.pyplot
        :return:  list of points from origin to ``point``
        """

        # follow parent links back to the origin (whose parent is None)
        path = []
        while point is not None:
            path.append(point)
            point = point.parent
        path.reverse()

        # visualise
        for p in path:
            rec = Rectangle(xy=(p.x, p.y), width=1, height=1, color='green')
            ax.add_patch(rec)
            plt.draw()
            self._save_image(plt)

        self._merge_video()

        tme = time.time()
        print("Algorithm finishes in {} s".format(int(tme - tms)))

        return path

    def _merge_video(self):
        """
        merge the frames saved by this instance into an .avi video in
        OUTPUT_DIR (5 fps), then delete the frame files
        """

        frame_files = self._frames
        self._frames = []
        if not frame_files:
            return

        video = None
        size = None
        try:
            for file_name in frame_files:
                image = cv2.imread(file_name)
                if image is None:
                    # unreadable frame: skip it rather than abort the whole video
                    continue
                if video is None:
                    # cv2.VideoWriter expects the frame size as (width, height)
                    height, width = image.shape[:2]
                    size = (width, height)
                    video_path = os.path.join(self.OUTPUT_DIR, '{}.avi'.format(round(time.time())))
                    fourcc = cv2.VideoWriter_fourcc(*'DIVX')
                    video = cv2.VideoWriter(video_path, fourcc, 5, size)
                elif (image.shape[1], image.shape[0]) != size:
                    # every frame must match the writer's size
                    image = cv2.resize(image, size)
                video.write(image)
        finally:
            if video is not None:
                video.release()

        # delete original image files
        for file_name in frame_files:
            os.remove(file_name)


main.py(主程序)

from matplotlib import pyplot as plt
from matplotlib.patches import Rectangle

from map import Map
from a_star import AStar


""" map settings """

width, height = 10, 15

origin, target = (0, 0), (width - 1, height - 1)
# three vertical walls: at x = round(width/4) and x = round(3*width/4) from the
# bottom up to 2/3 of the height, and at x = round(width/2) from 1/3 of the
# height to the top, so the path has to zig-zag between them
obstacles = [(round(width * (1 / 4)), j) for j in range(round(height * (2 / 3)))] + [
    (round(width * (1 / 2)), j) for j in range(round(height * (1 / 3)), height)] + [
    (round(width * (3 / 4)), j) for j in range(round(height * (2 / 3)))]

map_ = Map(width=width, height=height, obstacles=obstacles)


""" visual settings """

plt.figure(figsize=(5, 5))
ax = plt.gca()
ax.set_xlim([0, map_.width])
ax.set_ylim([0, map_.height])

# draw every cell: obstacles filled gray, free cells white with a gray border
for i in range(map_.width):
    for j in range(map_.height):
        if map_.is_obstacle(i, j):
            rectangle = Rectangle(xy=(i, j), width=1, height=1, color='gray')
        else:
            rectangle = Rectangle(xy=(i, j), width=1, height=1, edgecolor='gray', facecolor='white')
        ax.add_patch(rectangle)

# start cell in blue, goal cell in red
rectangle = Rectangle(xy=origin, width=1, height=1, facecolor='blue')
ax.add_patch(rectangle)
rectangle = Rectangle(xy=target, width=1, height=1, facecolor='red')
ax.add_patch(rectangle)

plt.axis('equal')  # set equal scaling
plt.axis('off')  # turn off axis lines and labels
plt.tight_layout()


""" algorithm """

# search between the same endpoints that were drawn above
a_star = AStar(map=map_, origin=origin, target=target)
a_star.run(ax, plt)

  • 0
    点赞
  • 16
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值