简单RRT算法的实现

动机

在阅读一篇介绍RRT算法的文章并浏览作者的开源代码时,发现其中存在一点小bug,而且代码不太好读,所以参考他的文章,简单重写了一下。

代码

简单来说就是使用各种类去建模RRT中的各个元素,并尽量让单一的类负责单一的功能。
新增类如下:

  • MatplotlibVisualizer: 负责可视化地图/障碍/采样点/路径
  • RectBarrier: 矩形障碍物,负责可视化数据生成+判断是否存在碰撞
  • Node: 节点,包含位置信息和父节点信息
  • NodeTree: 节点树,支持新增节点、判断是否已经包含节点、搜索最近节点、输出路径等。
  • RRT: 算法实现

以下直接贴源码

import numpy as np
import matplotlib.pyplot as plt
import warnings
import itertools
import random

# NOTE(review): this installs a process-wide "ignore" filter with no category
# or module restriction, so every warning (including numpy/matplotlib
# deprecation and runtime warnings) is hidden for the rest of the run.
warnings.filterwarnings('ignore')


class ProblemConfig:
    """Attribute-style view over the raw settings dictionary of a problem.

    Each required dictionary key is copied onto an attribute of the same
    name. A missing key raises ``KeyError``.
    """

    def __init__(self, problem_config: dict) -> None:
        required_keys = (
            "start_point",
            "end_point",
            "step",
            "fig_x_width",
            "fig_y_width",
            "barrier_x_range",
            "barrier_y_range",
        )
        # Copy the values in a fixed order so that the first missing key is
        # the one reported by the KeyError.
        for key in required_keys:
            setattr(self, key, problem_config[key])


class PlanningProblem:
    """Bundle a planner, a barrier and a visualizer for one planning problem.

    Typical use: construct from a ``ProblemConfig``, call :meth:`search`,
    then :meth:`show_fig`.
    """

    def __init__(self, problem_config: ProblemConfig) -> None:
        self.visualizer = MatplotlibVisualizer(problem_config.fig_x_width,
                                               problem_config.fig_y_width)
        self.barrier = RectBarrier(
            problem_config.barrier_x_range, problem_config.barrier_y_range)
        self.start_point = np.array(problem_config.start_point)
        self.end_point = np.array(problem_config.end_point)

        self.rrt_planner = RRT(problem_config)
        # Result of the last search; stays None until search() is called.
        self.path = None

    def search(self):
        """Run the planner and remember the path it returns."""
        self.path = self.rrt_planner.search()

    def show_fig(self):
        """Draw barrier, start/end points, path and sampled points, then show.

        The path is drawn only when one is available. If ``search`` was
        never called, or the planner returned an empty path, the rest of
        the scene is still shown instead of failing on the column slicing
        done by the visualizer.
        """
        self.visualizer.plot_barrier(self.barrier.visual_array)
        self.visualizer.plot_start_point(self.start_point)
        self.visualizer.plot_end_point(self.end_point)
        if self.path is not None and len(self.path) > 0:
            self.visualizer.plot_path(np.asarray(self.path))
        # reshape(-1, 2) keeps an empty sample list two-dimensional, so the
        # visualizer can still slice its x and y columns.
        rand_points = np.array(
            self.rrt_planner.rand_points_set).reshape(-1, 2)
        self.visualizer.plot_random_set(rand_points)
        new_points = np.array(self.rrt_planner.new_point_set).reshape(-1, 2)
        self.visualizer.plot_new_set(new_points)
        self.visualizer.show()


class MatplotlibVisualizer:
    """Draw the planning scene through the ``matplotlib.pyplot`` interface.

    All drawing goes through the module-level ``plt`` functions, so it
    targets whatever figure pyplot currently considers active.
    """

    def __init__(self, x_width: int, y_width: int):
        plt.figure()
        # Horizontal axis: range, label and one tick per integer.
        plt.xlim(-1, x_width)
        plt.xlabel('x')
        plt.xticks(np.arange(x_width))
        # Vertical axis: same treatment.
        plt.ylim(-1, y_width)
        plt.ylabel('y')
        plt.yticks(np.arange(y_width))
        plt.grid()

    def plot_barrier(self, barrier_data: np.ndarray):
        """Fill the barrier in black.

        ``barrier_data`` columns are read as x, lower y and upper y.
        """
        xs = barrier_data[:, 0]
        lower = barrier_data[:, 1]
        upper = barrier_data[:, 2]
        plt.fill_between(xs, lower, upper, color='black')

    def plot_start_point(self, pos: np.ndarray):
        """Mark the start position with a red dot."""
        x, y = pos[0], pos[1]
        plt.plot(x, y, 'ro')

    def plot_end_point(self, pos: np.ndarray):
        """Mark the end position with a yellow dot."""
        x, y = pos[0], pos[1]
        plt.plot(x, y, marker='o', color='yellow')

    def plot_path(self, path: np.ndarray):
        """Draw the path as a solid line; columns are read as x and y."""
        xs = path[:, 0]
        ys = path[:, 1]
        plt.plot(xs, ys, '-', linewidth=2)

    def plot_random_set(self, random_set: np.ndarray):
        """Scatter the randomly sampled points in cyan."""
        xs = random_set[:, 0]
        ys = random_set[:, 1]
        plt.scatter(xs, ys, color='cyan')

    def plot_new_set(self, new_set: np.ndarray):
        """Scatter the newly created tree points in violet."""
        xs = new_set[:, 0]
        ys = new_set[:, 1]
        plt.scatter(xs, ys, color='violet')

    def show(self):
        """Display the figure."""
        plt.show()


class RectBarrier:
    """Axis-aligned rectangular obstacle with a safety margin.

    The rectangle spans ``x_range`` by ``y_range``. For both drawing and
    collision tests it is inflated on every side by ``resolution``.
    """

    def __init__(self, x_range: np.ndarray, y_range: np.ndarray):
        self.start_x, self.end_x = x_range[0], x_range[1]
        self.start_y, self.end_y = y_range[0], y_range[1]
        self.resolution = 0.3
        self.visual_array = self.generate_visual()

    def generate_visual(self):
        """Return a (100, 3) array of x, lower-y and upper-y for drawing."""
        margin = self.resolution
        xs = np.linspace(self.start_x - margin, self.end_x + margin, 100)
        bottom = np.full_like(xs, self.start_y - margin)
        top = np.full_like(xs, self.end_y + margin)
        return np.column_stack((xs, bottom, top))

    def in_collsion_region(self, point: np.ndarray):
        """Tell whether ``point`` lies inside the inflated rectangle.

        The bounds are inclusive.
        """
        margin = self.resolution
        px, py = point[0], point[1]
        inside_x = self.start_x - margin <= px <= self.end_x + margin
        inside_y = self.start_y - margin <= py <= self.end_y + margin
        return inside_x and inside_y

    def collsion_exist(self, point_start: np.ndarray, point_end: np.ndarray):
        """Tell whether the segment between the two points hits the barrier.

        Both end points are tested first. The segment is then cut into
        equal pieces (roughly 0.1 long, never longer) and every interior
        division point is tested. A barrier thinner than that spacing could
        be missed.
        """
        if self.in_collsion_region(point_start):
            return True
        if self.in_collsion_region(point_end):
            return True
        delta = point_end - point_start
        spacing = 0.1
        pieces = int(np.linalg.norm(delta) / spacing) + 1
        stride = delta / float(pieces)
        return any(self.in_collsion_region(point_start + k * stride)
                   for k in range(1, pieces))


class Node:
    """A tree node: a position plus the indices linking it into the tree.

    ``idx`` and ``father_idx`` both start at -1, meaning "not assigned".
    """

    def __init__(self, pos: np.ndarray):
        self.pos = pos
        self.father_idx = -1
        self.idx = -1

    def set_father(self, father_idx: int):
        """Record the index of this node's parent."""
        self.father_idx = father_idx

    def set_idx(self, idx: int):
        """Record this node's own index."""
        self.idx = idx

    def is_root(self) -> bool:
        """Return True when the node has no parent.

        A parent index of 0 is a valid parent, so "no parent" is the
        negative sentinel left by ``__init__``, not "index <= 0".
        """
        return self.father_idx < 0


class NodeTree:
    """Index-addressed tree of nodes grown from a start node.

    Tree nodes live in ``node_map`` under the keys ``0 .. node_ct - 1``.
    The end node is kept separately in ``end_node`` and is not part of
    ``node_map``.
    """

    def __init__(self):
        self.node_map = {}
        self.node_ct = 0
        self.end_node = None

    def add_start_node(self, node: Node):
        """Store ``node`` under index 0."""
        node.set_idx(0)
        self.node_map[0] = node
        self.node_ct += 1

    def add_end_node(self, node: Node):
        """Remember ``node`` as the end node; it is given index -1."""
        node.set_idx(-1)
        self.end_node = node

    def add_node(self, node: Node, father_idx: int):
        """Append ``node`` to the tree as a child of ``father_idx``."""
        node.set_father(father_idx)
        node.set_idx(self.node_ct)
        self.node_map[self.node_ct] = node
        self.node_ct += 1

    def contains(self, pos: np.ndarray) -> bool:
        """Return True when ``pos`` coincides with the end node or a tree node.

        "Coincides" means an l1 distance below 1e-3.
        """
        tolerance = 1e-3
        # The end node is not stored in node_map, and distance_idx() answers
        # inf for any negative index, so it has to be compared directly.
        if (self.end_node is not None
                and np.linalg.norm(self.end_node.pos - pos, ord=1) < tolerance):
            return True
        return any(self.distance_idx(pos, i) < tolerance
                   for i in range(self.node_ct))

    def distance_idx(self, pos: np.ndarray, idx: int) -> float:
        """Return the l1 distance from ``pos`` to tree node ``idx``.

        Indices outside ``0 .. node_ct - 1`` yield ``inf``.
        """
        if idx < 0 or idx >= self.node_ct:
            return np.inf
        return np.linalg.norm(self.node_map[idx].pos - pos, ord=1)

    def get_path(self) -> np.ndarray:
        """Return the positions from the start node to the end node.

        The path is built backwards: the end node, then the most recently
        appended tree node, then each parent until a node without a parent
        is reached. The list is then reversed. The most recently appended
        node is therefore taken to be the one that connects to the end
        node.
        """
        path = [self.end_node.pos]
        last_node = self.node_map[self.node_ct - 1]
        path.append(last_node.pos)
        while last_node.father_idx >= 0:
            last_node = self.node_map[last_node.father_idx]
            path.append(last_node.pos)
        path.reverse()
        return np.array(path)


class RRT:
    """Rapidly-exploring random tree planner on an integer grid workspace.

    Random samples are drawn from the grid cells of the figure. The tree
    grows from the start point in steps of ``step`` towards each sample.
    The search stops as soon as the newest node can be joined to the end
    point by a collision-free straight segment.
    """

    def __init__(self, config: ProblemConfig) -> None:
        # workspace: every integer (x, y) cell inside the figure
        self.workspace = list(itertools.product(range(config.fig_x_width),
                                                range(config.fig_y_width)))
        self.barrier = RectBarrier(
            config.barrier_x_range, config.barrier_y_range)
        # node tree
        self.start_node = Node(np.array(config.start_point))
        self.end_node = Node(np.array(config.end_point))
        self.node_tree = NodeTree()
        self.node_tree.add_start_node(self.start_node)
        self.node_tree.add_end_node(self.end_node)
        # state of the current search step
        self.rand_p = np.zeros(2)
        self.near_p = np.zeros(2)
        self.new_p = np.zeros(2)
        self.near_idx = -1
        # history kept for visualisation
        self.rand_points_set = []
        self.new_point_set = []
        self.step = config.step

    def search(self, max_trial: int = 100):
        """Grow the tree until the end point is reachable.

        Args:
            max_trial: number of retries after the first growth attempt,
                so at most ``max_trial + 1`` attempts are made.

        Returns:
            The path as an (n, 2) array from start to end, or an empty
            (0, 2) array when no path was found. Both outcomes are arrays,
            so callers can slice columns or test ``len(path) == 0``.
        """
        for _ in range(max_trial + 1):
            # Only a step that actually added a node may end the search;
            # otherwise new_p is stale (or still the initial zeros) and a
            # path built from it would be meaningless.
            if self.search_once() and self.end_search():
                return self.node_tree.get_path()
        return np.zeros((0, 2))

    def search_once(self):
        """Try to add one node to the tree; return True on success."""
        self.get_rand_point()
        self.get_near_point()
        max_trial = 100
        trial_ct = 0
        while (self.barrier.collsion_exist(self.near_p, self.rand_p)
               and trial_ct < max_trial):
            self.get_rand_point()
            self.get_near_point()
            trial_ct += 1
        if trial_ct >= max_trial:
            return False
        if np.linalg.norm(self.rand_p - self.near_p, ord=2) < 1e-3:
            # The sample sits on its nearest node (get_rand_point ran out
            # of retries): there is no direction to grow in.
            return False
        self.new_p = self.search_direction()
        self.rand_points_set.append(self.rand_p.copy())
        self.new_point_set.append(self.new_p.copy())
        new_node = Node(self.new_p)
        self.node_tree.add_node(new_node, self.near_idx)
        return True

    def end_search(self):
        """Return True when the newest point connects to the end point
        without crossing the barrier."""
        return not self.barrier.collsion_exist(self.new_p, self.end_node.pos)

    def new_rand_point(self):
        """Draw one workspace cell at random into ``rand_p``."""
        rand_point = random.sample(self.workspace, 1)[0]
        self.rand_p = np.array(rand_point)

    def get_rand_point(self) -> bool:
        """Sample until the point is not already in the tree.

        Returns False when every retry still produced a known point.
        """
        self.new_rand_point()
        max_trial = 100
        trial_ct = 0
        while self.node_tree.contains(self.rand_p) and trial_ct < max_trial:
            self.new_rand_point()
            trial_ct += 1
        return (trial_ct < max_trial)

    def get_near_point(self):
        """Find the tree node closest to ``rand_p``.

        Stores its index in ``near_idx`` and its position in ``near_p``.
        Returns whether a node at finite distance was found.
        """
        dis = np.inf
        self.near_idx = -1
        for i in range(0, self.node_tree.node_ct):
            dis_tmp = self.node_tree.distance_idx(self.rand_p, i)
            if dis_tmp < dis:
                dis = dis_tmp
                self.near_idx = i
                self.near_p = self.node_tree.node_map[i].pos
        return np.isfinite(dis)

    def search_direction(self) -> np.ndarray:
        """Return the point one step from ``near_p`` towards ``rand_p``.

        The step never goes past ``rand_p``, because only the segment from
        ``near_p`` to ``rand_p`` has been checked for collisions. When the
        two points coincide the result is ``near_p`` itself.
        """
        vec = self.rand_p - self.near_p
        dis2 = np.linalg.norm(vec, ord=2)
        if dis2 < 1e-3:
            return self.near_p.copy()
        if dis2 <= self.step:
            return self.rand_p.astype(float)
        return self.near_p + vec / dis2 * self.step


if __name__ == "__main__":
    # Demo scenario: plan from (6, 4) to (17, 5) around one rectangular
    # barrier, then display the result.
    demo_settings = {
        "start_point": [6, 4],
        "end_point": [17, 5],
        "step": 1,
        "fig_x_width": 25,
        "fig_y_width": 12,
        "barrier_x_range": [10, 16],
        "barrier_y_range": [2, 8],
    }
    planner = PlanningProblem(ProblemConfig(demo_settings))
    planner.search()
    planner.show_fig()

  • 6
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
以下是一个简单RRT算法的Python实现示例:

```python
import numpy as np
import matplotlib.pyplot as plt

class RRT:
    def __init__(self, start, goal, obstacles, xlim, ylim, max_iter=1000, step_size=0.5):
        self.start = start
        self.goal = goal
        self.obstacles = obstacles
        self.xlim = xlim
        self.ylim = ylim
        self.max_iter = max_iter
        self.step_size = step_size
        self.nodes = []

    def generate_random_node(self):
        x = np.random.uniform(self.xlim[0], self.xlim[1])
        y = np.random.uniform(self.ylim[0], self.ylim[1])
        return np.array([x, y])

    def find_nearest_node(self, point):
        distances = [np.linalg.norm(point - node) for node in self.nodes]
        nearest_node_index = np.argmin(distances)
        return nearest_node_index

    def is_collision_free(self, point):
        for obstacle in self.obstacles:
            if obstacle.contains_point(point):
                return False
        return True

    def steer(self, start, end):
        direction = end - start
        norm = np.linalg.norm(direction)
        if norm <= self.step_size:
            return end
        else:
            return start + direction * (self.step_size / norm)

    def generate_path(self):
        self.nodes.append(self.start)
        for _ in range(self.max_iter):
            random_node = self.generate_random_node()
            nearest_node_index = self.find_nearest_node(random_node)
            nearest_node = self.nodes[nearest_node_index]
            new_node = self.steer(nearest_node, random_node)
            if not self.is_collision_free(new_node):
                continue
            self.nodes.append(new_node)
            if np.linalg.norm(new_node - self.goal) < self.step_size:
                self.nodes.append(self.goal)
                break
        if len(self.nodes) > 0 and np.linalg.norm(self.nodes[-1] - self.goal) >= self.step_size:
            return None
        return self.nodes

    def plot(self):
        plt.figure()
        for obstacle in self.obstacles:
            plt.plot(*obstacle.exterior.xy, 'r-')
        if self.nodes is not None:
            path = np.array(self.nodes)
            plt.plot(path[:, 0], path[:, 1], 'b-')
        plt.plot(self.start[0], self.start[1], 'go')
        plt.plot(self.goal[0], self.goal[1], 'ro')
        plt.xlim(*self.xlim)
        plt.ylim(*self.ylim)
        plt.gca().set_aspect('equal', adjustable='box')
        plt.show()
```
该示例中的`RRT`类实现了RRT算法的基本逻辑。您可以通过设置起点、目标点、障碍物以及环境的x和y边界来使用该算法,然后调用`generate_path`方法生成路径,最后调用`plot`方法可视化路径和环境。

请注意,该示例中的障碍物使用了`shapely`库来表示和检测碰撞。您可以根据自己的需求进行适当的修改和扩展。

希望这个示例对您有所帮助!

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值