(21)函数plot_graph用于绘制图形及其上的节点、边和障碍物。给定图形对象和障碍物列表,它会绘制图形的节点、起点、终点以及图形之间的边,并在图形上添加圆形障碍物。
def plot_graph(graph: Graph, obstacles: list):
"""
绘制图形和障碍物。
:param graph: 要绘制的图形
:param obstacles: 障碍物列表
"""
xes = [pos[0] for id, pos in graph.vertices.items()]
yes = [pos[1] for id, pos in graph.vertices.items()]
plt.scatter(xes, yes, c='gray') # 绘制节点
plt.scatter(graph.start[0], graph.start[1], c='#49ab1f', s=50) # 绘制起点
plt.scatter(graph.goal[0], graph.goal[1], c='red', s=50) # 绘制目标点
edges = [(graph.vertices[id_ver], graph.vertices[child]) for pos_ver, id_ver in graph.id_vertex.items()
for child in graph.children[id_ver]]
for edge in edges:
plt.plot([edge[0][0], edge[1][0]], [edge[0][1], edge[1][1]], c='black', alpha=0.5) # 绘制边
# 绘制障碍物
plt.gca().set_aspect('equal', adjustable='box')
for obstacle in obstacles:
circle = plt.Circle(obstacle[0], obstacle[1], color='black')
plt.gca().add_patch(circle)
plt.xlim(0, graph.width)
plt.ylim(0, graph.height)
(22)函数nearest_node_kdtree用于查找最接近输入节点的节点,并检查是否穿过障碍物。它采用了 KD 树数据结构来快速搜索最近的节点。给定图形对象、节点位置、障碍物列表以及可选的分离树节点列表和 KD 树,函数返回与输入节点最接近且不穿过障碍物的节点的位置和ID。
def nearest_node_kdtree(G: Graph, vertex: tuple, obstacles: list, separate_tree_nodes: list = (), kdtree: cKDTree = None):
"""
检查距离输入节点最近的节点,检查是否穿过障碍物。
:param G: 图(Graph)
:param vertex: 要查找其邻居的节点的位置
:param obstacles: 障碍物列表
:param separate_tree_nodes: 分离树中的节点列表
:param kdtree: 除新节点外所有节点创建的KD树
:return: new_vertex, new_id
"""
try:
id = G.id_vertex[vertex]
return np.array(vertex), id
except KeyError:
closest_id = None
closest_pos = None
nn = 1
while True:
d, i = kdtree.query(vertex, k=nn, workers=-1)
if nn == 1:
closest_pos = kdtree.data[i]
else:
closest_pos = kdtree.data[i[-1]]
closest_id = G.id_vertex[closest_pos[0], closest_pos[1]]
line = Line(vertex, closest_pos)
nn += 1
if not through_obstacle(line, obstacles):
break
elif nn > len(G.vertices):
closest_pos = np.array(vertex)
closest_id = None
break
return closest_pos, closest_id
(23)函数nearest_node用于查找给定节点最近的图中的节点,同时检查是否穿过障碍物。如果输入节点已经是图中的节点之一,则直接返回该节点的位置和ID。否则,它会遍历图中的每个节点,排除分离树中的节点,并计算到输入节点的距离。然后,它将返回距离最近的节点的位置和ID。
def nearest_node(G: Graph, vertex: tuple, obstacles: list, separate_tree_nodes: list = ()):
"""
检查距离输入节点最近的节点,检查是否穿过障碍物。
:param G: 图(Graph)
:param vertex: 要查找其邻居的节点的位置
:param obstacles: 障碍物列表
:param separate_tree_nodes: 分离树中的节点列表
:return: new_vertex, new_id
"""
try:
id = G.id_vertex[vertex]
return vertex, id
except KeyError:
min_distance = float("inf")
new_id = None
new_vertex = None
for ver_id, ver in G.vertices.items():
if ver_id in separate_tree_nodes: continue
line = Line(ver, vertex)
if through_obstacle(line, obstacles): continue
distance = calc_distance(ver, vertex)
if distance < min_distance:
min_distance = distance
new_id = ver_id
new_vertex = ver
return new_vertex, new_id
(24)函数steer用于确定从一个给定点(父节点)到另一个给定点(目标节点)的方向,并返回新节点的位置,该位置位于给定长度的边上。如果两个点之间的距离大于最大允许长度,则返回的新节点将在指定方向上与父节点相距最大长度;否则,返回目标节点的位置。
def steer(to_vertex: tuple, from_vertex: tuple, max_length: float) -> tuple:
"""
返回新节点的位置。从顶点到目标顶点的方向,给定长度。
:param to_vertex: 标记方向的顶点的位置
:param from_vertex: 父顶点的位置
:param max_length: 两个节点之间允许的最大边长
:return: 新节点的位置
"""
distance = calc_distance(to_vertex, from_vertex)
x_vect_norm = (to_vertex[0] - from_vertex[0]) / distance
y_vect_norm = (to_vertex[1] - from_vertex[1]) / distance
x_pos = from_vertex[0] + x_vect_norm * max_length
y_pos = from_vertex[1] + y_vect_norm * max_length
if distance > max_length:
return x_pos, y_pos
return to_vertex
(25)函数check_solution用于检查是否已找到解决方案,即检查新节点是否足够接近目标节点,以判断是否达到了目标。
def check_solution(G: Graph, q_new: tuple, node_radius: int) -> bool:
"""
检查是否已找到解决方案(节点是否足够接近目标节点)。
:param G: 图(Graph)
:param q_new: 要检查的节点
:param node_radius: 节点的半径
:return: 如果找到解决方案,则返回 True,否则返回 False
"""
dist_to_goal = calc_distance(q_new, G.goal) # 检查是否到达目标点
if dist_to_goal < 2 * node_radius:
return True
return False
(26)函数plot_path用于绘制路径,参数分别一个图形对象 G,一个路径的节点ID列表 path,可选的标题字符串 title 和路径的成本 cost。函数plot_path通过连接路径中相邻节点的线段来绘制路径,并在标题中显示路径的成本。
def plot_path(G: Graph, path: list, title: str = "", cost: float = float("inf")):
"""
绘制路径。
:param G: 图(Graph)
:param path: 路径中节点的ID列表
:param title: 图的标题
:param cost: 路径的成本
"""
prev_node = G.goal
for point in path:
plt.plot((prev_node[0], G.vertices[point][0]), (prev_node[1], G.vertices[point][1]), c='#057af7', linewidth=2)
prev_node = G.vertices[point]
plt.title(title + f" cost: {round(cost, 2)}")
(27)函数find_path用于从图中的起始节点找到路径,直到到达指定的根节点。它返回从起始节点到根节点的路径列表以及路径的总成本。
def find_path(G: Graph, from_node: int, root_node: int) -> tuple:
"""
从起始节点找到路径。
:param G: 图
:param from_node: 起始节点
:param root_node: 根节点
:return: 路径,成本
"""
path = []
node = from_node
cost = 0
try:
while node != root_node:
path.append(node)
cost += G.cost[node]
node = G.parent[node]
path.append(root_node)
except Exception:
pass
return path, cost
(28)函数forced_removal用于从图中删除一个随机的无子节点的节点,它分别接受一个图对象(Graph)、不会被删除的节点的ID以及路径中的节点列表作为输入参数,并返回被删除的节点的ID。
def forced_removal(G: Graph, id_new: int, path: list) -> int:
"""
从图中删除一个随机的无子节点的节点。
:param G: 图(Graph)
:param id_new: 不会被删除的节点的ID
:param path: 节点列表中的路径
:return: 被删除的节点的ID
"""
id_last_in_path = -1
if path:
id_last_in_path = path[0]
childless_nodes = [node for node, children in G.children.items() if len(children) == 0] # and node != id_new
id_ver = random.choice(childless_nodes)
while id_ver == id_new or id_ver == id_last_in_path:
id_ver = random.choice(childless_nodes)
G.remove_vertex(id_ver)
return id_ver
(29)函数choose_parent_kdtree用于在给定的搜索半径内查找最优的父节点,以使从起始节点到新节点的成本最小化。它通过在KD树中查找半径内的所有节点,并计算它们到新节点的距离来实现此目的。然后,它检查每个候选节点是否通过障碍物,并比较其到新节点的成本是否比当前最佳边的成本更低。
def choose_parent_kdtree(G: Graph, q_new: tuple, id_new: int, best_edge: tuple,
radius: float, obstacles: list, separate_tree_nodes: list = (), kdtree: cKDTree = None) -> tuple:
"""
找到在成本上最优的节点到起始节点。
:param G: 图(Graph)
:param q_new: 新节点的位置
:param id_new: 新节点的ID
:param best_edge: 到目前为止最佳边
:param radius: 搜索区域的半径
:param obstacles: 障碍物列表
:param separate_tree_nodes: 在分离树中的节点ID列表
:param kdtree: 除新节点之外的所有节点创建的kdtree
:return: 最佳节点的ID
"""
i = kdtree.query_ball_point(q_new, r=radius, workers=-1)
in_radius_pos = kdtree.data[i] # 半径内的点的位置
in_radius_ids = [G.id_vertex[pos[0], pos[1]] for pos in in_radius_pos] # 半径内的点的ID
costs = np.linalg.norm(in_radius_pos - q_new, axis=1)
# new_costs = [G.get_cost(id_in_radius) + costs[] for id_in_radius in in_radius_ids]
for id_ver, vertex, cost in zip(in_radius_ids, in_radius_pos, costs):
line = Line(vertex, q_new)
if through_obstacle(line, obstacles): continue
if G.get_cost(id_new) > G.get_cost(id_ver) + cost:
G.cost[id_new] = cost
best_edge = (id_new, id_ver, cost)
return best_edge
(30)函数choose_parent用于在给定的搜索半径内查找最优的父节点,以使从起始节点到新节点的成本最小化。它通过迭代图中的所有节点来实现此目的,并计算每个节点到新节点的距离。然后,它检查每个候选节点是否通过障碍物,并比较其到新节点的成本是否比当前最佳边的成本更低。
def choose_parent(G: Graph, q_new: tuple, id_new: int, best_edge: tuple,
radius: float, obstacles: list, separate_tree_nodes: list = ()) -> tuple:
"""
寻找到起始节点成本最优的节点。
:param G: 图(Graph)
:param q_new: 新节点的位置
:param id_new: 新节点的ID
:param best_edge: 到目前为止最优的边
:param radius: 搜索区域的半径
:param obstacles: 障碍物列表
:param separate_tree_nodes: 分离树中节点的ID列表
:return: 最优节点的ID
"""
for id_ver, vertex in G.vertices.items(): # 遍历所有顶点
if id_ver == id_new: continue
distance_new_vert = calc_distance(q_new, vertex) # 计算新节点到顶点节点的距离
if round(distance_new_vert, 3) > radius: continue # 如果距离大于搜索半径,则继续
line = Line(vertex, q_new) # 创建从新节点到顶点的直线对象
if through_obstacle(line, obstacles): continue # 如果直线穿过障碍物,则继续
if G.get_cost(id_new) > G.get_cost(id_ver) + distance_new_vert: # 如果从新节点到顶点的成本小于当前成本,则重置顶点到新节点的成本
G.cost[id_new] = distance_new_vert
best_edge = (id_new, id_ver, distance_new_vert)
return best_edge
(31)函数rewire_kdtree实现了RRT_STAR算法的重连过程,用于更新图中节点的连接关系和成本。它根据给定的新节点位置和搜索半径,在搜索范围内查找节点,并检查是否可以通过将这些节点重新连接到新节点来降低其成本。如果发现可以通过重新连接来降低节点的成本,则更新图中相应节点的父节点和子节点,并更新其成本。
def rewire_kdtree(G: Graph, q_new: tuple, id_new: int, radius: float, obstacles: list, kdtree: cKDTree = None):
"""
RRT_STAR算法的重连过程。
:param G: 图(Graph)
:param q_new: 新节点的位置
:param id_new: 新节点的ID
:param radius: 搜索区域的半径
:param obstacles: 障碍物列表
:param kdtree: 除新节点外的所有节点构建的kdtree
"""
i = kdtree.query_ball_point(q_new, r=radius, workers=-1)
# 计算新节点与搜索半径内节点之间的距离
in_radius_pos = kdtree.data[i] # 搜索半径内的点的位置
in_radius_ids = [G.id_vertex[pos[0], pos[1]] for pos in in_radius_pos] # 搜索半径内的点的ID列表
costs = np.linalg.norm(in_radius_pos - q_new, axis=1)
# new_costs = [G.get_cost(id_in_radius) + costs[] for id_in_radius in in_radius_ids]
for id_ver, vertex, cost in zip(in_radius_ids, in_radius_pos, costs):
line = Line(vertex, q_new)
if through_obstacle(line, obstacles): continue
if G.get_cost(id_ver) > G.get_cost(id_new) + cost:
parent = G.parent[id_ver] # 重连节点的父节点
del G.children[parent][G.children[parent].index(id_ver)] # 从其父节点的子节点列表中删除重连节点
G.parent[id_ver] = id_new # 将重连节点的父节点设置为新节点
G.children[id_new].append(id_ver) # 将重连节点添加到新节点的子节点列表中
G.cost[id_ver] = cost
(32)函数rewire实现了RRT_STAR算法的重连过程,用于更新图中节点的连接关系和成本。它遍历所有的节点,排除起始节点和新节点,并在指定的搜索半径内查找节点。对于在搜索范围内的每个节点,它计算新节点与该节点之间的距离,并检查是否可以通过将这些节点重新连接到新节点来降低其成本。如果发现可以通过重新连接来降低节点的成本,则更新图中相应节点的父节点和子节点,并更新其成本。
def rewire(G: Graph, q_new: tuple, id_new: int, radius: float, obstacles: list):
"""
RRT_STAR算法的重连过程。
:param G: 图(Graph)
:param q_new: 新节点的位置
:param id_new: 新节点的ID
:param radius: 搜索区域的半径
:param obstacles: 障碍物列表
"""
for id_ver, vertex in G.vertices.items():
if id_ver == G.id_vertex[G.start]: continue
if id_ver == id_new: continue
distance_new_vert = calc_distance(q_new, vertex)
if distance_new_vert > radius: continue
line = Line(vertex, q_new)
if through_obstacle(line, obstacles): continue
if G.get_cost(id_ver) > G.get_cost(id_new) + distance_new_vert:
parent = G.parent[id_ver] # 重连节点的父节点
del G.children[parent][G.children[parent].index(id_ver)] # 从父节点的子节点中删除重连的节点
G.parent[id_ver] = id_new # 将重连节点的父节点设置为新节点
G.children[id_new].append(id_ver) # 将重连节点添加到新节点的子节点中
G.cost[id_ver] = distance_new_vert
(33)函数get_distance_dict用于计算给定节点与图中其他节点之间的距离,并返回以节点 ID 为键,距离为值的字典
def get_distance_dict(G: Graph, node_to_check: int, indeces_to_check: list[int]) -> dict:
pos = G.vertices[node_to_check]
tree_points_list = [vertex for id_ver, vertex in G.vertices.items()]
tree_points = np.array(tree_points_list)
new_point = np.array(pos).reshape(-1, 2)
x2 = np.sum(tree_points ** 2, axis=1).reshape(-1, 1)
y2 = np.sum(new_point ** 2, axis=1).reshape(-1, 1)
xy = 2 * np.matmul(tree_points, new_point.T)
dists = np.sqrt(x2 - xy + y2.T)
distances = {id_ver: id_and_cost[0] for id_ver, id_and_cost in zip(G.vertices, dists)}
return distances
(34)函数calc_distance用于计算两点之间的欧几里得距离。它接受表示两点坐标的两个元组作为输入,并返回两点之间的距离作为浮点数值。
def calc_distance(p1: tuple, p2: tuple) -> float:
"""
计算两个点之间的距离。
:param p1: 点 1
:param p2: 点 2
:return: 点之间的距离
"""
return np.linalg.norm(np.array(p1) - np.array(p2))