Sarsa Upgraded: Solving the Maze Again with Sarsa-λ

Sarsa-λ (Sarsa Lambda) is a variant of the Sarsa algorithm. The λ is a parameter between 0 and 1 that balances how much the current state and all previously visited states contribute to each update.

Sarsa is an incremental, temporal-difference learning method closely related to Q-learning: by continually exploring and learning in the actual environment it gradually updates its value estimates and policy, eventually learning a good behaviour policy.

Sarsa-λ builds on Sarsa by introducing trace decay ("λ decay") to balance the importance of the current state against all earlier states. Instead of looking only at the current reward and the Q value of the next state, Sarsa-λ also keeps track of all previously visited state-action pairs and uses the λ parameter to weigh how much each of them is updated. This gives learning a longer horizon: credit for a reward can be traced back along the path of actions that led to it.

By contrast, plain Sarsa only uses the current state and the next state in each update, so a reward propagates backwards one step at a time and learning is slower and less far-sighted.

Overall, Sarsa-λ is better suited than plain Sarsa to tasks with long time dependencies and delayed rewards, at the cost of being somewhat more complex and computationally heavier.
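
To make the role of λ concrete, here is a minimal dict-based sketch of one backward-view Sarsa(λ) update (my own illustration, not the code used below): compute the one-step TD error exactly as in plain Sarsa, mark the visited state-action pair as eligible, then apply the error to every recorded pair in proportion to its trace and let all traces decay by γλ.

from collections import defaultdict

Q = defaultdict(float)  # state-action values
E = defaultdict(float)  # eligibility traces

def sarsa_lambda_update(s, a, r, s_, a_, terminal, alpha=0.1, gamma=0.9, lam=0.9):
    # one-step TD error, exactly as in plain Sarsa
    target = r if terminal else r + gamma * Q[(s_, a_)]
    error = target - Q[(s, a)]
    # mark the visited pair as eligible (replacing trace)
    E[(s, a)] = 1.0
    # spread the error over every recorded pair in proportion to its trace,
    # then let all traces fade by gamma * lambda
    for key in list(E.keys()):
        Q[key] += alpha * error * E[key]
        E[key] *= gamma * lam

With lam = 0 the traces die out immediately after one use and the update collapses to ordinary one-step Sarsa; with lam closer to 1 each TD error is pushed far back along the visited trail.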

Enough talk; let's see what changes in the code.
First, the environment:

import numpy as np
import time
import tkinter as tk

# some layout constants for the maze
UNIT=40
WIDTH=4
HIGHT=4

class Palace(tk.Tk,object):
    def __init__(self):
        super(Palace, self).__init__()
        # action space
        self.action_space = ['u', 'd', 'l', 'r']
        # self.n_action=len(self.action_space)
        self.title('maze')
        # set up the window (width x height)
        self.geometry('{0}x{1}'.format(WIDTH * UNIT, HIGHT * UNIT))
        self.build_maze()
    def build_maze(self):
        self.canvas = tk.Canvas(self, bg='white', height=HIGHT * UNIT, width=WIDTH * UNIT)
        # draw the grid lines
        for i in range(0, WIDTH * UNIT, UNIT):
            x0, y0, x1, y1 = i, 0, i, HIGHT * UNIT
            self.canvas.create_line(x0, y0, x1, y1)
        for j in range(0, HIGHT * UNIT, UNIT):
            x0, y0, x1, y1 = 0, j, WIDTH * UNIT, j
            self.canvas.create_line(x0, y0, x1, y1)

        # create the "hell" cells of the maze
        hell_center1 = np.array([100, 20])
        self.hell1 = self.canvas.create_rectangle(hell_center1[0] - 15, hell_center1[1] - 15, hell_center1[0] + 15,
                                                  hell_center1[1] + 15, fill='black')
        hell_center2 = np.array([20, 100])
        self.hell2 = self.canvas.create_rectangle(hell_center2[0] - 15, hell_center2[1] - 15, hell_center2[0] + 15,
                                                  hell_center2[1] + 15, fill='green')

        # create the exit
        out_center = np.array([100, 100])
        self.oval = self.canvas.create_oval(out_center[0] - 15, out_center[1] - 15, out_center[0] + 15,
                                            out_center[1] + 15, fill='yellow')

        # the agent
        origin = np.array([20, 20])
        self.finder = self.canvas.create_rectangle(origin[0] - 15, origin[1] - 15, origin[0] + 15, origin[1] + 15,
                                                   fill='red')

        self.canvas.pack()  # don't forget the parentheses -- pack() must be called

    # one exploration step of the agent
    def step(self, action):
        s = self.canvas.coords(self.finder)  # current position of the agent
        # canvas.move expects an offset, so start from a zero displacement
        base_action = np.array([0, 0])
        # choose the displacement according to the action
        if action == 'u':
            if s[1] > UNIT:
                base_action[1] -= UNIT
        elif action == 'd':
            if s[1] < (HIGHT - 1) * UNIT:
                base_action[1] += UNIT
        elif action == 'l':
            if s[0] > UNIT:
                base_action[0] -= UNIT
        elif action == 'r':
            if s[0] < (WIDTH - 1) * UNIT:
                base_action[0] += UNIT

        # move the agent
        self.canvas.move(self.finder, base_action[0], base_action[1])
        # record the new position after moving
        s_ = self.canvas.coords(self.finder)

        # reward feedback; "terminal" is not assigned by hand, it is determined by where the agent lands
        if s_ == self.canvas.coords(self.oval):
            reward = 1
            done = True
            s_ = 'terminal'  # episode finished
        elif s_ in (self.canvas.coords(self.hell2), self.canvas.coords(self.hell1)):
            reward = -1
            done = True
            s_ = 'terminal'
        else:
            reward = 0
            done = False
        # step returns three things: the next state, the reward and the done flag
        return s_, reward, done

    def reset(self):
        self.update()
        time.sleep(0.5)
        self.canvas.delete(self.finder)
        origin = np.array([20, 20])
        self.finder = self.canvas.create_rectangle(
            origin[0] - 15, origin[1] - 15,
            origin[0] + 15, origin[1] + 15,
            fill='red')
        # return observation
        return self.canvas.coords(self.finder)

    def render(self):
        time.sleep(0.05)
        self.update()
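
Before moving on, a quick way to sanity-check this environment on its own (a minimal sketch, assuming the class above is saved as maze_env.py, the module name the run script below also uses):

from maze_env import Palace
import numpy as np

env = Palace()

def random_walk():
    # drive the agent with random actions just to watch it move and hit the hells or the exit
    observation = env.reset()
    for _ in range(20):
        env.render()
        action = np.random.choice(env.action_space)
        observation, reward, done = env.step(action)
        if done:
            observation = env.reset()
    env.destroy()

env.after(100, random_walk)
env.mainloop()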

The environment is basically unchanged. Next up is the agent:

"""
This part of the code is the learning brain of the agent.
All decisions are made in here.

View more on my tutorial page: https://morvanzhou.github.io/tutorials/
"""

import numpy as np
import pandas as pd


class RL(object):
    def __init__(self, action_space, learning_rate=0.01, reward_decay=0.9, e_greedy=0.9):
        self.actions = action_space  # a list
        self.lr = learning_rate
        self.gamma = reward_decay
        self.epsilon = e_greedy

        self.q_table = pd.DataFrame(columns=self.actions, dtype=np.float64)

    def check_state_exist(self, state):
        if state not in self.q_table.index:
            # append new state to the q table
            # (DataFrame.append was removed in pandas 2.0, so pandas < 2 is assumed here)
            self.q_table = self.q_table.append(
                pd.Series(
                    [0] * len(self.actions),
                    index=self.q_table.columns,
                    name=state,
                )
            )

    def choose_action(self, observation):
        self.check_state_exist(observation)
        # action selection
        if np.random.rand() < self.epsilon:
            # choose best action
            state_action = self.q_table.loc[observation, :]
            # several actions may have the same value; randomly choose one of them
            action = np.random.choice(state_action[state_action == np.max(state_action)].index)
        else:
            # choose random action
            action = np.random.choice(self.actions)
        return action

    def learn(self, *args):
        pass


# backward eligibility traces
class SarsaLambdaTable(RL):
    # note the extra parameter trace_decay: like the reward decay, it makes state-action pairs further from the reward contribute less
    def __init__(self, actions, learning_rate=0.01, reward_decay=0.9, e_greedy=0.9, trace_decay=0.9):
        super(SarsaLambdaTable, self).__init__(actions, learning_rate, reward_decay, e_greedy)

        # backward view, eligibility trace
        # lambda_ is the λ of Sarsa-λ: the trace-decay rate used in learn() below
        self.lambda_ = trace_decay
        # keep a copy of q_table with the same shape to hold the traces
        self.eligibility_trace = self.q_table.copy()

    def check_state_exist(self, state):
        if state not in self.q_table.index:
            # append new state to q table
            to_be_append = pd.Series(
                [0] * len(self.actions),
                index=self.q_table.columns,
                name=state,
            )
            self.q_table = self.q_table.append(to_be_append)

            # also update the eligibility trace
            # this copied table is kept in sync with the q table
            self.eligibility_trace = self.eligibility_trace.append(to_be_append)

    def learn(self, s, a, r, s_, a_):
        self.check_state_exist(s_)
        # first make sure the next state is in the table; if not, it gets added
        q_predict = self.q_table.loc[s, a]
        if s_ != 'terminal':
            # this is the "reality": q_target is the target value
            q_target = r + self.gamma * self.q_table.loc[s_, a_]  # next state is not terminal
        else:
            q_target = r  # next state is terminal
        # do not update directly; compute the TD error first and use it below
        error = q_target - q_predict

        # increase the trace for the visited state-action pair
        # without the trace this would just be a one-step update, which propagates reward slowly;
        # eligibility_trace records the trail of visited state-action pairs so the TD error
        # can be spread back along that trail

        # Method 1: accumulating trace
        # self.eligibility_trace.loc[s, a] += 1

        # Method 2: replacing trace (used here)
        self.eligibility_trace.loc[s, :] *= 0
        self.eligibility_trace.loc[s, a] = 1

        # Q update
        self.q_table += self.lr * error * self.eligibility_trace

        # decay eligibility trace after update
        self.eligibility_trace *= self.gamma * self.lambda_
        return self.q_table

In reinforcement learning, eligibility usually refers to how much a given state-action pair contributes to the value function: it measures the degree to which that pair should be credited (or blamed) when the value function is updated incrementally.

Eligibility traces are used in algorithms such as Sarsa-λ. Each state-action pair keeps an associated eligibility value describing how much it contributes to the current update of the value function, and that value is adjusted every time the value function is updated.

Normally the eligibility decays over time: the contribution of state-action pairs visited long ago gradually fades, while the most recently visited pair contributes the most. In Sarsa-λ the decay rate is controlled by the λ (together with γ), which balances the influence of past and present state-action pairs on the value function.
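
The Method 1 / Method 2 lines in learn() above correspond to the two standard kinds of traces: an accumulating trace keeps adding 1 each time a pair is visited, while a replacing trace (the one actually used here) zeroes the other actions of that state and caps the visited one at 1. A tiny pandas sketch of the difference, using the same table layout as eligibility_trace (illustrative only):

import pandas as pd

trace = pd.DataFrame([[0.0, 0.0, 0.0, 0.0]], columns=['u', 'd', 'l', 'r'], index=['some_state'])

# Method 1: accumulating trace -- revisiting the same pair keeps growing its trace
trace.loc['some_state', 'u'] += 1

# Method 2: replacing trace -- clear the row first, then set the visited action to exactly 1
trace.loc['some_state', :] *= 0
trace.loc['some_state', 'u'] = 1

With accumulating traces a pair visited many times in one episode can end up with a trace larger than 1, which makes its update correspondingly larger.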

Then the run script:

"""
Sarsa is an online updating method for reinforcement learning.

Unlike Q-learning, which is an off-policy updating method, Sarsa updates along the trajectory it is currently following.

You will see that Sarsa is more cautious when punishment is close, because it cares about all of its behaviours,
while Q-learning is bolder because it only cares about the maximum-value behaviour.
"""

from maze_env import Palace
from RL_brain import SarsaLambdaTable


def update():
    for episode in range(10):
        # initial observation
        observation = env.reset()

        # RL choose action based on observation
        action = RL.choose_action(str(observation))

        # reset the eligibility trace to all zeros at the start of every episode
        RL.eligibility_trace *= 0

        while True:
            # fresh env
            env.render()

            # RL take action and get next observation and reward
            observation_, reward, done = env.step(action)

            # RL choose action based on next observation
            action_ = RL.choose_action(str(observation_))

            # RL learns from this transition (s, a, r, s_, a_) ==> Sarsa
            q_table = RL.learn(str(observation), action, reward, str(observation_), action_)

            # swap observation and action
            observation = observation_
            action = action_

            # break while loop when end of this episode
            if done:
                break

    # end of game
    print('game over')
    print(q_table)
    q_table.to_csv('output.csv')

    env.destroy()


if __name__ == "__main__":
    env = Palace()
    RL = SarsaLambdaTable(actions=env.action_space)

    env.after(10, update)
    env.mainloop()

For some reason Sarsa-λ only sometimes beats plain Sarsa; it is not very stable. The choice of λ (and of replacing vs. accumulating traces) matters a lot here, since λ = 0 reduces to one-step Sarsa while λ close to 1 spreads each TD error far back along the trail. I will keep digging into this later.
