动机
在阅读一篇介绍RRT算法的文章时,我浏览了作者的开源代码,发现其中存在一些小bug,而且代码不太好读,于是参考原文简单重写了一下。
代码
简单来说就是使用各种类去建模RRT中的各个元素,并尽量让单一的类负责单一的功能。
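新增类如下:
- ProblemConfig: 问题配置,保存起点、终点、步长、画布尺寸和障碍物范围
- PlanningProblem: 组合下面各个类,负责执行搜索并展示结果
- MatplotlibVisualizer: 负责可视化地图/障碍/采样点/路径
- RectBarrier: 矩形障碍物,负责生成可视化数据,并判断是否存在碰撞
- Node: 节点,包含位置信息和父节点信息
- NodeTree: 节点树,支持新增节点、判断是否已经包含某个位置、计算到指定节点的距离(供最近节点搜索使用)、输出路径等。
- RRT: 算法实现(随机采样、最近节点搜索、节点扩展与终止判断)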
以下直接贴源码
import numpy as np
import matplotlib.pyplot as plt
import warnings
import itertools
import random
# Silence every warning process-wide for this demo script.
# NOTE(review): this also hides numpy/matplotlib warnings that may point at
# real problems; consider narrowing the filter to specific categories.
warnings.filterwarnings('ignore')
class ProblemConfig:
    """Attribute-style view of the raw problem dictionary.

    Required keys: start_point, end_point, step, fig_x_width, fig_y_width,
    barrier_x_range, barrier_y_range. Values are stored as given (no copy);
    any missing key raises KeyError, checked in the order listed above.
    """

    def __init__(self, problem_config: dict) -> None:
        cfg = problem_config
        # query endpoints and expansion step length
        self.start_point, self.end_point, self.step = (
            cfg["start_point"],
            cfg["end_point"],
            cfg["step"],
        )
        # drawing area size
        self.fig_x_width, self.fig_y_width = (
            cfg["fig_x_width"],
            cfg["fig_y_width"],
        )
        # rectangular barrier extent
        self.barrier_x_range, self.barrier_y_range = (
            cfg["barrier_x_range"],
            cfg["barrier_y_range"],
        )
class PlanningProblem:
    """Wire a ProblemConfig to an RRT planner and a matplotlib visualizer."""

    def __init__(self, problem_config: ProblemConfig) -> None:
        self.visualizer = MatplotlibVisualizer(problem_config.fig_x_width,
                                               problem_config.fig_y_width)
        # Barrier instance used for drawing; the planner builds its own copy.
        self.barrier = RectBarrier(
            problem_config.barrier_x_range, problem_config.barrier_y_range)
        self.start_point = np.array(problem_config.start_point)
        self.end_point = np.array(problem_config.end_point)
        self.rrt_planner = RRT(problem_config)
        self.path = None  # filled in by search()

    def search(self):
        """Run the planner, store the result in ``self.path`` and return it.

        The planner returns an empty list when it could not reach the goal.
        """
        self.path = self.rrt_planner.search()
        return self.path

    def show_fig(self):
        """Draw barrier, endpoints, path and sampled points, then show the figure.

        A missing path (search() not run yet, or no path found) and empty
        sample sets are skipped instead of crashing the plotting calls.
        """
        self.visualizer.plot_barrier(self.barrier.visual_array)
        self.visualizer.plot_start_point(self.start_point)
        self.visualizer.plot_end_point(self.end_point)
        if self.path is not None and len(self.path) > 0:
            self.visualizer.plot_path(np.asarray(self.path))
        rand_points = np.array(self.rrt_planner.rand_points_set)
        if rand_points.size > 0:
            self.visualizer.plot_random_set(rand_points)
        new_points = np.array(self.rrt_planner.new_point_set)
        if new_points.size > 0:
            self.visualizer.plot_new_set(new_points)
        self.visualizer.show()
class MatplotlibVisualizer:
    """Draw the planning scene with matplotlib.pyplot on a freshly opened figure."""

    def __init__(self, x_width: int, y_width: int):
        """Open a figure with axis limits (-1, x_width) / (-1, y_width) and integer grid ticks."""
        plt.figure()
        plt.xlim(-1, x_width)
        plt.ylim(-1, y_width)
        plt.xlabel('x')
        plt.ylabel('y')
        plt.xticks(np.arange(x_width))
        plt.yticks(np.arange(y_width))
        plt.grid()

    @staticmethod
    def _as_points(data) -> np.ndarray:
        """Return ``data`` as an ndarray; ``None`` or empty input becomes an empty (0, 2) array."""
        if data is None:
            return np.empty((0, 2))
        points = np.asarray(data)
        if points.size == 0:
            return np.empty((0, 2))
        return points

    def plot_barrier(self, barrier_data: np.ndarray):
        """Fill the barrier area; columns are (x, y_lower, y_upper) as built by RectBarrier."""
        plt.fill_between(
            barrier_data[:, 0], barrier_data[:, 1], barrier_data[:, 2], color='black')

    def plot_start_point(self, pos: np.ndarray):
        """Mark the start point with a red circle."""
        plt.plot(pos[0], pos[1], 'ro')

    def plot_end_point(self, pos: np.ndarray):
        """Mark the end point with a yellow circle."""
        plt.plot(pos[0], pos[1], marker='o', color='yellow')

    def plot_path(self, path: np.ndarray):
        """Plot the path as a polyline; ``None``/empty paths (planner failure) are skipped."""
        points = self._as_points(path)
        if len(points) == 0:
            return
        plt.plot(points[:, 0], points[:, 1], '-', linewidth=2)

    def plot_random_set(self, random_set: np.ndarray):
        """Scatter the random samples in cyan; nothing is drawn for an empty set."""
        points = self._as_points(random_set)
        if len(points) == 0:
            return
        plt.scatter(points[:, 0], points[:, 1], color='cyan')

    def plot_new_set(self, new_set: np.ndarray):
        """Scatter the newly added tree nodes in violet; nothing is drawn for an empty set."""
        points = self._as_points(new_set)
        if len(points) == 0:
            return
        plt.scatter(points[:, 0], points[:, 1], color='violet')

    def show(self):
        """Display the figure (blocking, as plt.show() does)."""
        plt.show()
class RectBarrier:
    """Axis-aligned rectangular barrier, inflated by ``resolution`` on every side.

    Drawing data and collision checks both use the inflated rectangle
    [start_x - resolution, end_x + resolution] x [start_y - resolution, end_y + resolution],
    with its boundary counted as colliding.
    """

    def __init__(self, x_range: np.ndarray, y_range: np.ndarray, resolution: float = 0.3):
        """
        Args:
            x_range: (start_x, end_x) of the barrier.
            y_range: (start_y, end_y) of the barrier.
            resolution: safety margin added on every side (0.3 by default).
        """
        self.start_x = x_range[0]
        self.end_x = x_range[1]
        self.start_y = y_range[0]
        self.end_y = y_range[1]
        self.resolution = resolution
        self.visual_array = self.generate_visual()

    def generate_visual(self):
        """Return a (100, 3) array of (x, y_lower, y_upper) rows for fill_between."""
        resolution = self.resolution
        x_range = np.linspace(self.start_x - resolution,
                              self.end_x + resolution, 100)
        y_lower = np.ones_like(x_range) * (self.start_y - resolution)
        y_upper = np.ones_like(x_range) * (self.end_y + resolution)
        return np.c_[x_range, y_lower, y_upper]

    def in_collsion_region(self, point: np.ndarray):
        """Return True if ``point`` lies inside (or on) the inflated rectangle."""
        resolution = self.resolution
        x_in = (self.start_x - resolution <=
                point[0] <= self.end_x + resolution)
        y_in = (self.start_y - resolution <=
                point[1] <= self.end_y + resolution)
        return (x_in and y_in)

    def collsion_exist(self, point_start: np.ndarray, point_end: np.ndarray):
        """Return True if the segment point_start -> point_end touches the inflated rectangle.

        Uses the exact slab (Liang-Barsky style) segment/box intersection test,
        so even a segment that only clips a corner of the rectangle is detected.
        """
        if self.in_collsion_region(point_start) or self.in_collsion_region(point_end):
            return True
        p0 = np.asarray(point_start, dtype=float)
        direction = np.asarray(point_end, dtype=float) - p0
        r = self.resolution
        lower = (self.start_x - r, self.start_y - r)
        upper = (self.end_x + r, self.end_y + r)
        # Parametrise the segment as p0 + t * direction, t in [0, 1], and
        # shrink [t_enter, t_exit] to the part lying between both slab pairs.
        t_enter, t_exit = 0.0, 1.0
        for axis in range(2):
            if direction[axis] == 0:
                # Parallel to this slab pair: either always inside it or never.
                if p0[axis] < lower[axis] or p0[axis] > upper[axis]:
                    return False
                continue
            t0 = (lower[axis] - p0[axis]) / direction[axis]
            t1 = (upper[axis] - p0[axis]) / direction[axis]
            if t0 > t1:
                t0, t1 = t1, t0
            t_enter = max(t_enter, t0)
            t_exit = min(t_exit, t1)
            if t_enter > t_exit:
                return False
        return True
class Node:
    """A tree vertex: its position plus the index of its parent in the NodeTree."""

    def __init__(self, pos: np.ndarray):
        self.pos = pos
        self.father_idx = -1  # -1 means "no parent"
        self.idx = -1         # assigned when the node is added to a NodeTree

    def set_father(self, father_idx: int):
        """Record the index of this node's parent."""
        self.father_idx = father_idx

    def set_idx(self, idx: int):
        """Record this node's own index."""
        self.idx = idx

    def is_root(self) -> bool:
        """Return True if the node has no parent (negative ``father_idx``)."""
        # Previously read the misspelled ``father_dix`` (AttributeError) and
        # tested ``> 0``, which is the opposite of "has no parent".
        return self.father_idx < 0
class NodeTree:
    """Index-addressed storage for the nodes grown by the planner.

    The start node lives at index 0 of ``node_map``. Every added node gets the
    next consecutive index, so ``node_ct - 1`` is always the newest node. The
    goal node is kept separately in ``end_node`` (its idx is set to -1) and is
    not part of ``node_map``.
    """

    def __init__(self):
        self.node_map = {}    # idx -> Node
        self.node_ct = 0      # nodes in node_map; also the next index to assign
        self.end_node = None  # goal node, set by add_end_node()

    def add_start_node(self, node: Node):
        """Store ``node`` as the root at index 0."""
        node.set_idx(0)
        self.node_map[0] = node
        self.node_ct += 1

    def add_end_node(self, node: Node):
        """Remember ``node`` as the goal; it is not inserted into ``node_map``."""
        node.set_idx(-1)
        self.end_node = node

    def add_node(self, node: Node, father_idx: int):
        """Append ``node`` as a child of the node stored at ``father_idx``."""
        node.set_father(father_idx)
        node.set_idx(self.node_ct)
        self.node_map[self.node_ct] = node
        self.node_ct += 1

    def contains(self, pos: np.ndarray, tol: float = 1e-3) -> bool:
        """Return True if ``pos`` is within L1 distance ``tol`` of the goal or any stored node.

        The goal is compared directly: it is not in ``node_map``, so looking it
        up as index -1 through distance_idx() would always give inf.
        """
        if (self.end_node is not None
                and np.linalg.norm(self.end_node.pos - pos, ord=1) < tol):
            return True
        return any(self.distance_idx(pos, idx) < tol for idx in range(self.node_ct))

    def distance_idx(self, pos: np.ndarray, idx: int) -> float:
        """L1 distance from ``pos`` to the node at ``idx``; inf if ``idx`` is not a stored index."""
        if idx < 0 or idx >= self.node_ct:
            return np.inf
        return np.linalg.norm(self.node_map[idx].pos - pos, ord=1)

    def get_path(self) -> np.ndarray:
        """Return the waypoints start -> ... -> newest node -> goal, one row per waypoint.

        Assumes the newest node is the one connected to the goal, and that
        add_end_node() has been called.
        """
        node = self.node_map[self.node_ct - 1]
        path = [self.end_node.pos, node.pos]
        # Walk parent links back to the root (father_idx == -1).
        while node.father_idx >= 0:
            node = self.node_map[node.father_idx]
            path.append(node.pos)
        path.reverse()
        return np.array(path)
class RRT:
    """Basic RRT planner over the integer grid points of the figure area.

    Each expansion draws a random grid point that is not already in the tree
    and finds the tree node nearest to it (L1 distance). If the straight
    segment between them is collision-free, a new node is added at most
    ``step`` away from that nearest node, toward the sample. The search stops
    once the newest node has a collision-free straight line to the goal.
    """

    def __init__(self, config: ProblemConfig) -> None:
        # workspace: every integer (x, y) with 0 <= x < fig_x_width, 0 <= y < fig_y_width
        self.workspace = list(itertools.product(range(config.fig_x_width),
                                                range(config.fig_y_width)))
        self.barrier = RectBarrier(
            config.barrier_x_range, config.barrier_y_range)
        # node tree: the start is the root, the goal is stored separately
        self.start_node = Node(np.array(config.start_point))
        self.end_node = Node(np.array(config.end_point))
        self.node_tree = NodeTree()
        self.node_tree.add_start_node(self.start_node)
        self.node_tree.add_end_node(self.end_node)
        # state of the current expansion step
        self.rand_p = np.zeros(2)
        self.near_p = np.zeros(2)
        self.new_p = np.zeros(2)
        self.near_idx = -1
        # history of samples / added nodes, used for visualization
        self.rand_points_set = []
        self.new_point_set = []
        self.step = config.step

    def search(self, max_trial: int = 100):
        """Grow the tree until the newest node can see the goal.

        Args:
            max_trial: extra expansion attempts after the first one
                (at most ``max_trial + 1`` calls to search_once()).

        Returns:
            Array of waypoints from start to goal, or ``[]`` if the goal was
            not connected within the attempt budget.
        """
        for _ in range(max_trial + 1):
            if self.search_once() and self.end_search():
                return self.node_tree.get_path()
        return []

    def search_once(self, max_trial: int = 100) -> bool:
        """Try to add one node to the tree.

        Draws up to ``max_trial + 1`` (sample, nearest node) pairs until the
        segment between them is collision-free, then steers from the nearest
        node toward the sample.

        Returns:
            True if a node was added, False if every attempt failed.
        """
        for _ in range(max_trial + 1):
            if not self.get_rand_point():
                continue  # no sample outside the tree was found; try again
            self.get_near_point()
            if not self.barrier.collsion_exist(self.near_p, self.rand_p):
                break
        else:
            return False
        self.new_p = self.search_direction()
        self.rand_points_set.append(self.rand_p.copy())
        self.new_point_set.append(self.new_p.copy())
        self.node_tree.add_node(Node(self.new_p), self.near_idx)
        return True

    def end_search(self) -> bool:
        """Return True if the newest tree node has a collision-free line to the goal.

        The newest node is used because get_path() connects exactly that node
        to the goal (it is the start node before any expansion).
        """
        newest = self.node_tree.node_map[self.node_tree.node_ct - 1]
        return not self.barrier.collsion_exist(newest.pos, self.end_node.pos)

    def new_rand_point(self):
        """Draw one workspace grid point uniformly into ``self.rand_p``."""
        self.rand_p = np.array(random.choice(self.workspace))

    def get_rand_point(self, max_trial: int = 100) -> bool:
        """Draw a sample that is not already a tree node or the goal.

        Returns:
            True on success (within ``max_trial + 1`` draws), else False; in
            the failure case ``rand_p`` holds the last (rejected) draw.
        """
        for _ in range(max_trial + 1):
            self.new_rand_point()
            if not self.node_tree.contains(self.rand_p):
                return True
        return False

    def get_near_point(self) -> bool:
        """Set ``near_idx``/``near_p`` to the tree node closest (L1) to ``rand_p``.

        Returns:
            True if a node was found (the tree is non-empty).
        """
        best = np.inf
        self.near_idx = -1
        for idx in range(self.node_tree.node_ct):
            dist = self.node_tree.distance_idx(self.rand_p, idx)
            if dist < best:
                best = dist
                self.near_idx = idx
                self.near_p = self.node_tree.node_map[idx].pos
        return bool(np.isfinite(best))

    def search_direction(self) -> np.ndarray:
        """Return the new node position: ``near_p`` moved toward ``rand_p`` by at most ``step``."""
        vec = self.rand_p - self.near_p
        dist = np.linalg.norm(vec, ord=2)
        if dist <= self.step:
            # Stop on the sample instead of overshooting it: the new edge must
            # stay inside the near_p -> rand_p segment that was collision-checked.
            return self.rand_p.astype(float)
        return self.near_p + vec / dist * self.step
if __name__ == "__main__":
    # Demo: plan from (6, 4) to (17, 5) around one rectangular barrier, then plot.
    demo_settings = {
        "start_point": [6, 4],
        "end_point": [17, 5],
        "step": 1,
        "fig_x_width": 25,
        "fig_y_width": 12,
        "barrier_x_range": [10, 16],
        "barrier_y_range": [2, 8],
    }
    demo_config = ProblemConfig(demo_settings)
    planning_problem = PlanningProblem(demo_config)
    planning_problem.search()
    planning_problem.show_fig()