LRTA* (Learning-RTA*)

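This article introduces LRTA*, a modification of RTA* intended for repeated problem-solving trials. LRTA* stores the best neighbor estimate rather than the second-best one, which avoids inflated state values. The code example shows one attempt at implementing the idea; its effect in that code is limited, so the focus is on the principle of the algorithm and its key update step.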

1. Basic concepts

LRTA* is an improvement on RTA*. The original RTA* paper says:

Unfortunately, while RTA* as described above is ideally suited to single problem solving trials, it must be modified to accommodate multi-trial learning. The reason is that the algorithm records the second best estimate in the previous state, which represents an accurate estimate of that state looking back from the perspective of the next state. However, if the best estimate turns out to be correct, then storing the second best value can result in inflated values for some states. These inflated values will direct the next agents in the wrong direction on subsequent problem solving trials.

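In other words:

RTA* stores the second-best estimate in the state it has just left, which is an accurate estimate of that state as seen from the next state. If the best estimate turns out to be correct, however, storing the second-best value can inflate the values of some states, and those inflated values will steer the agent in the wrong direction in later problem-solving trials.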
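The difference between the two stored values can be illustrated with a small sketch. It assumes a hand-made list of neighbor estimates f = g(s, s') + h(s') with at least two neighbors; the function names `rta_star_value` and `lrta_star_value` are hypothetical and exist only for this illustration.

```python
def lrta_star_value(neighbor_f):
    """Value LRTA* stores for the state being left: the best neighbor estimate."""
    return min(neighbor_f)


def rta_star_value(neighbor_f):
    """Value RTA* stores for the state being left: the second-best estimate.

    Assumption for this sketch: the state has at least two neighbors.
    """
    return sorted(neighbor_f)[1]


# Hypothetical estimates f = g(s, s') + h(s') for three neighbors of s.
neighbor_f = [3.0, 5.0, 7.0]

print(lrta_star_value(neighbor_f))  # 3.0
print(rta_star_value(neighbor_f))   # 5.0
```

If the best estimate (3.0) is in fact correct, the value 5.0 kept by RTA* overstates the cost of this state on a later trial, which is the inflation the quoted passage describes.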

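LRTA* solves this problem by storing the best estimate instead of the second-best one. The core of LRTA* is the update of each node's heuristic value, which uses the following rule: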
$$H(s) = g(s, s') + H(s')$$
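Here H(s) is the heuristic value of the previous node s, g(s, s') is the cost of moving from s to s', and H(s') is the heuristic value of the current node s' (the neighbor the agent moved to, that is, the one with the smallest g(s, s') + H(s')). The complete LRTA* algorithm is as follows: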

[Figure: the complete LRTA* algorithm; image not available]
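The h(s) update is designed to keep the search from getting trapped in a local minimum, which can happen when the search is guided by a fixed heuristic. In LRTA*, if the agent falls into a local minimum, every visit to the nearby nodes raises their h(s), and with it the total cost f(s). After h(s) has been raised enough times, the agent leaves the local minimum.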
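The escape from a local minimum can be sketched on a tiny hand-made graph. Everything in this example is an assumption made for illustration: the graph, the node names, the initial h values, and the helper `lrta_star_trial`. The update keeps the larger of the old and new value, so h never decreases. Nodes P and Q form a dead-end pocket whose initial h values look attractive.

```python
def lrta_star_trial(graph, h, start, goal):
    """Run one LRTA* trial. h is updated in place; returns the visited path."""
    path = [start]
    s = start
    while s != goal:
        # f(s') = g(s, s') + h(s') for every neighbor s' of s
        best_f, best_n = min((cost + h[n], n) for n, cost in graph[s].items())
        # learning step: raise h(s) to the best neighbor estimate
        h[s] = max(h[s], best_f)
        s = best_n  # move to the best neighbor
        path.append(s)
    return path


# Undirected graph with unit edge costs: Q - P - S - A - B - G
graph = {
    "S": {"P": 1, "A": 1},
    "P": {"S": 1, "Q": 1},
    "Q": {"P": 1},
    "A": {"S": 1, "B": 1},
    "B": {"A": 1, "G": 1},
    "G": {"B": 1},
}
# Initial heuristic: never overestimates, but makes the pocket look good.
h = {"S": 2, "P": 1, "Q": 1, "A": 2, "B": 1, "G": 0}

first = lrta_star_trial(graph, h, "S", "G")
second = lrta_star_trial(graph, h, "S", "G")  # reuses the learned h

print(first)   # ['S', 'P', 'Q', 'P', 'S', 'A', 'B', 'G']
print(second)  # ['S', 'A', 'B', 'G']
print(h)       # {'S': 3, 'P': 3, 'Q': 3, 'A': 2, 'B': 1, 'G': 0}
```

In the first trial the agent walks into the pocket, and the h values of P, Q and S rise until the route through A becomes the cheaper choice. The second trial starts from the learned values and goes straight to the goal.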

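h(s) is updated as follows:

For example, take a move from node s to node s', with step cost g(s, s') = 1 and h(s') = 2 (this value is known from the start, because the initial h is the estimated distance from that node to the goal). After the agent moves from s to s', it is h(s) that gets updated: h(s) = g(s, s') + h(s') = 1 + 2 = 3.

If the agent then wants to go back from s' to s, h(s') has to be updated: h(s') = g(s', s) + h(s) = 1 + 3 = 4.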
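The two updates above can be checked with a few lines of code. This is only a sketch of the arithmetic: the helper `update_h` is a hypothetical name, and it assumes that the node being moved to is the best (here, the only relevant) neighbor.

```python
def update_h(h, s, s_next, g):
    """Apply h(s) = g(s, s') + h(s') and return the new h(s)."""
    h[s] = g + h[s_next]
    return h[s]


h = {"s'": 2}  # h(s') = 2 is known initially

print(update_h(h, "s", "s'", 1))  # moving s -> s': h(s) = 1 + 2 = 3
print(update_h(h, "s'", "s", 1))  # moving s' -> s: h(s') = 1 + 3 = 4
```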

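The essence of the algorithm lies in the update of h(s). Once that update is understood, LRTA* is mostly understood.

For a detailed example, see the first entry in the references; it is not repeated here.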

2. Code example:

import math
import copy
import matplotlib.pyplot as plt


class LRTAStar:
    """A*-style best-first search (priority f = g + h) combined with an
    LRTA*-style update of the stored heuristic values h.
    """
    def __init__(self, s_start, s_goal, heuristic_type,xI, xG):
        self.s_start = s_start
        self.s_goal = s_goal
        self.heuristic_type = heuristic_type

        self.u_set = [(-1, 0), (-1, 1), (0, 1), (1, 1),
                        (1, 0), (1, -1), (0, -1), (-1, -1)]  # feasible input set
        self.obs = self.obs_map()  # position of obstacles

        self.OPEN = dict()  # priority queue / OPEN set
        self.CLOSED = []  # CLOSED set / VISITED order
        self.PARENT = dict()  # recorded parent
        self.h = dict()
        self.g = dict()  # cost to come
        self.x_range = 51  # size of background
        self.y_range = 31
        
        self.xI, self.xG = xI, xG  # start and goal, used for plotting

    def update_obs(self, obs):
        self.obs = obs

    def animation(self, path, visited, name):
        self.plot_grid(name)
        self.plot_visited(visited)
        self.plot_path(path)
        plt.show()


    def plot_grid(self, name):
        obs_x = [x[0] for x in self.obs]
        obs_y = [x[1] for x in self.obs]

        plt.plot(self.xI[0], self.xI[1], "bs")
        plt.plot(self.xG[0], self.xG[1], "gs")
        plt.plot(obs_x, obs_y, "sk")
        plt.title(name)
        plt.axis("equal")

    def plot_visited(self, visited, cl='gray'):
        if self.xI in visited:
            visited.remove(self.xI)

        if self.xG in visited:
            visited.remove(self.xG)

        count = 0

        for x in visited:
            count += 1
            plt.plot(x[0], x[1], color=cl, marker='o')
            plt.gcf().canvas.mpl_connect('key_release_event',
                                         lambda event: [exit(0) if event.key == 'escape' else None])

            if count < len(visited) / 3:
                length = 20
            elif count < len(visited) * 2 / 3:
                length = 30
            else:
                length = 40
            #
            # length = 15

            if count % length == 0:
                plt.pause(0.001)
        plt.pause(0.01)

    def plot_path(self, path, cl='r', flag=False):
        path_x = [path[i][0] for i in range(len(path))]
        path_y = [path[i][1] for i in range(len(path))]

        if not flag:
            plt.plot(path_x, path_y, linewidth='3', color='r')
        else:
            plt.plot(path_x, path_y, linewidth='3', color=cl)

        plt.plot(self.xI[0], self.xI[1], "bs")
        plt.plot(self.xG[0], self.xG[1], "gs")

        plt.pause(0.01)


    def obs_map(self):
        """
        Initialize obstacles' positions
        :return: map of obstacles
        """

        x = 51
        y = 31
        obs = set()

        for i in range(x):
            obs.add((i, 0))
        for i in range(x):
            obs.add((i, y - 1))

        for i in range(y):
            obs.add((0, i))
        for i in range(y):
            obs.add((x - 1, i))

        for i in range(10, 21):
            obs.add((i, 15))
        for i in range(15):
            obs.add((20, i))

        for i in range(15, 30):
            obs.add((30, i))
        for i in range(16):
            obs.add((40, i))

        return obs

    def searching(self):
        """
        A*-style search with an LRTA*-style heuristic update.
        :return: path, visited order
        """

        self.PARENT[self.s_start] = self.s_start
        self.g[self.s_start] = 0
        self.g[self.s_goal] = math.inf
        self.OPEN[self.s_start] = self.f_value(self.s_start, self.s_start)

        while self.OPEN:
            s = min(self.OPEN, key=self.OPEN.get)  # node with the smallest f
            self.OPEN.pop(s)
            self.CLOSED.append(s)

            if s == self.s_goal:  # stop condition
                break

            new_h = math.inf  # min over neighbors of cost(s, s_n) + h(s_n)
            for s_n in self.get_neighbor(s):
                step = self.cost(s, s_n)  # math.inf if the move collides
                if step == math.inf:
                    continue
                new_cost = self.g[s] + step
                new_h = min(new_h, step + self.h.get(s_n, self.heuristic(s_n)))
                if s_n not in self.g or new_cost < self.g[s_n]:
                    self.g[s_n] = new_cost
                    self.PARENT[s_n] = s
                    self.OPEN[s_n] = self.f_value(s, s_n)

            # LRTA* update: use the best neighbor, and never lower h(s)
            if new_h > self.h[s]:
                self.h[s] = new_h

        print("iterations:", len(self.CLOSED))  # one CLOSED entry per loop
        return self.extract_path(self.PARENT), self.CLOSED

    def get_neighbor(self, s):
        """
        Return the 8-connected neighbors of state s.
        Obstacles are not filtered here; cost() returns math.inf for them.
        :param s: state
        :return: neighbors
        """

        return [(s[0] + u[0], s[1] + u[1]) for u in self.u_set]

    def cost(self, s_start, s_goal):
        """
        Calculate Cost for this motion
        :param s_start: starting node
        :param s_goal: end node
        :return:  Cost for this motion
        :note: Cost function could be more complicate!
        """

        if self.is_collision(s_start, s_goal):
            return math.inf

        return math.hypot(s_goal[0] - s_start[0], s_goal[1] - s_start[1])

    def is_collision(self, s_start, s_end):
        """
        check if the line segment (s_start, s_end) is collision.
        :param s_start: start node
        :param s_end: end node
        :return: True: is collision / False: not collision
        """

        if s_start in self.obs or s_end in self.obs:
            return True

        if s_start[0] != s_end[0] and s_start[1] != s_end[1]:
            if s_end[0] - s_start[0] == s_start[1] - s_end[1]:
                s1 = (min(s_start[0], s_end[0]), min(s_start[1], s_end[1]))
                s2 = (max(s_start[0], s_end[0]), max(s_start[1], s_end[1]))
            else:
                s1 = (min(s_start[0], s_end[0]), max(s_start[1], s_end[1]))
                s2 = (max(s_start[0], s_end[0]), min(s_start[1], s_end[1]))

            if s1 in self.obs or s2 in self.obs:
                return True

        return False

    def f_value(self, s, s_n):
        """
        f = g + h for s_n (g: cost to come, h: stored heuristic value).
        :param s: current state (unused; kept so existing calls still work)
        :param s_n: state whose f value is computed
        :return: f
        """
        if s_n not in self.h:
            self.h[s_n] = self.heuristic(s_n)  # initialize h on first use
        return self.g[s_n] + self.h[s_n]

    def extract_path(self, PARENT):
        """
        Extract the path based on the PARENT set.
        :return: The planning path
        """

        path = [self.s_goal]
        s = self.s_goal

        while True:
            s = PARENT[s]
            path.append(s)

            if s == self.s_start:
                break

        return list(path)

    def heuristic2(self, s1, s2):
        """
        Distance between two arbitrary states in the configured metric.
        Unlike cost(), this ignores obstacles.
        """
        if self.heuristic_type == "manhattan":
            return abs(s1[0] - s2[0]) + abs(s1[1] - s2[1])
        else:  # Euclidean distance: sqrt(dx^2 + dy^2)
            return math.hypot(s1[0] - s2[0], s1[1] - s2[1])

    def heuristic(self, s):
        """
        Calculate heuristic.
        :param s: current node (state)
        :return: heuristic function value
        """

        heuristic_type = self.heuristic_type  # heuristic type
        goal = self.s_goal  # goal node

        if heuristic_type == "manhattan":
            return abs(goal[0] - s[0]) + abs(goal[1] - s[1])
        else:#sqrt(x^2+y^2)
            return math.hypot(goal[0] - s[0], goal[1] - s[1])


def main():
    s_start = (5, 5)
    s_goal = (45, 25)

    astar = LRTAStar(s_start, s_goal, "euclidean",s_start,s_goal)

    path, visited = astar.searching()
    astar.animation(path, visited, "LRTA*")  # animation



if __name__ == '__main__':
    main()

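When run, the code above executes the same number of loop iterations as the A* version. That is a little odd, and I am not sure whether my understanding is wrong or the update really has limited effect here. LRTA* is mainly meant to get out of local minima quickly, and that situation may simply not arise in this example, so no improvement shows up. One likely reason is that the code still expands nodes from a global OPEN list exactly as A* does, and it only updates h for nodes that have already been expanded, so the update has little chance to change the expansion order. Overall, this is best read as a way to understand the idea of the algorithm; its practical benefit here seems very limited.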

References:
1. 《[PR] LRTA* 搜索算法》 ([PR] LRTA* search algorithm)
2. 《[AI] LRTA*搜索算法及其扩展算法》 ([AI] The LRTA* search algorithm and its extensions)
