【莫烦强化学习】视频笔记(三)2.SARSA学习实现走迷宫

第8节 SARSA学习实现走迷宫


之前一篇文章已经介绍过Q学习实现走迷宫的程序编写,对Q学习的整个过程也有了更加深刻的了解,文章链接:【莫烦强化学习】视频笔记(二)3.Q_Learning算法实现走迷宫
这里只介绍与Q学习不同的(需要修改的)代码部分,最后给出整个源代码,如有错误请各位批评指正,感谢~


8.1 SARSA-Learning类

之前介绍的Q-Learning类,有初始化、选择动作、学习更新参数、查看状态是否存在四个模块,其中初始化(全局参数)、选择动作、查看状态是否存在这几个函数部分的功能,Q学习与SARSA学习是没有区别的,区别在于学习更新参数这个模块。

    def learn(self, s, a, r, s_, a_):
        self.check_state_exit(s_)  # 查看状态s_是否存在,s_是在选择动作之后与环境交互获得的下一状态
        q_predict = self.q_table.loc[s, a]  # 当前状态s和动作a对应的Q值
        if s_ != 'terminal':  # 若下一步不是终态
            q_target = r + self.gamma * self.q_table.loc[s_, a_]  # 下一动作已经采样得到,直接使用s'与a'的Q值即可
        else:
            q_target = r;  # 否则直接为立即回报
        self.q_table.loc[s, a] += self.lr * (q_target - q_predict)  # 更新Q(s,a)

当然,也可以使用继承类的方式,继承之前写好的Qlearning类。


8.2 主循环

应当注意的是,这里的环境并不需要改变,环境是固定的,只能够对当前动作进行反馈。所以还是之前所介绍的maze_env环境代码,maze_env环境请参考莫烦的GITHUB
主循环相比于Q学习有一些变化,在更新的附近,按照之前与Q学习的不同和SARSA学习介绍的阐述,伪代码如下:
在这里插入图片描述

def update():  # 更新主函数
    for episode in range(100):  # 玩游戏的局数
        observation = env.reset()  # 初始化环境
        action = RL.choose_action(str(observation))
        while True:
            env.render()  # 刷新图像
            observation_, reward, done = env.step(action)  # 动作与环境交互,获得下一状态、奖励值和是否为终态的反馈
            action_ = RL.choose_action(str(observation_))  # 直接通过ε-greedy获得下一个动作a'
            RL.learn(str(observation), action, reward, str(observation_), action_)  # 更新Q表
            observation = observation_  # 转移到下一状态
            action = action_  # 动作直接就是刚才的动作
            if done:
                break
    print('Game Over')  # 游戏结束
    env.destroy()  # 关闭窗口

8.3 全代码一览
主循环 main.py
from SARSAlearning import SARSALearning
from maze_env import Maze


def update():  # 更新主函数
    for episode in range(100):  # 玩游戏的局数
        observation = env.reset()  # 初始化环境
        action = RL.choose_action(str(observation))
        while True:
            env.render()  # 刷新图像
            observation_, reward, done = env.step(action)  # 动作与环境交互,获得下一状态、奖励值和是否为终态的反馈
            action_ = RL.choose_action(str(observation_))  # 直接通过ε-greedy获得下一个动作a'
            RL.learn(str(observation), action, reward, str(observation_), action_)  # 更新Q表
            observation = observation_  # 转移到下一状态
            action = action_  # 动作直接就是刚才的动作
            if done:
                break
    print('Game Over')  # 游戏结束
    env.destroy()  # 关闭窗口


if __name__ == '__main__':
    env = Maze()  # 创建环境
    RL = SARSALearning(actions=list(range(env.n_actions)))  # Q学习类

    env.after(100, update)  # 100ms后调用函数
    env.mainloop()  # 开始可视化环境
SARSALearning类 SARSALearning.py
import numpy as np
import pandas as pd


class SARSALearning:
    def __init__(self, actions, learning_rate=0.01, reward_decay=0.9, e_greedy=0.9):  # 初始换函数,后面是默认参数
        self.actions = actions  # 动作空间
        self.lr = learning_rate  # 学习率
        self.gamma = reward_decay  # 奖励衰减
        self.epsilon = e_greedy  # 贪婪度
        self.q_table = pd.DataFrame(columns=self.actions, dtype=np.float64)  # 初始Q表

    def check_state_exit(self, state):  # 输入状态
        if state not in self.q_table.index:  # Q表中没有该状态
            # 插入新的行,Q值初始化为0
            self.q_table = self.q_table.append(
                pd.Series([0] * len(self.actions), index=self.q_table.columns, name=state))

    def choose_action(self, observation):  # 根据当前的状态选择动作
        self.check_state_exit(observation)  # 检查状态是否存在,不存在添加到Q表中
        if np.random.uniform() < self.epsilon:  # 直接选择Q值最大的动作
            state_action = self.q_table.loc[observation, :]  # 选择对应的一行
            # 由于Q值最大的动作也有可能有多个,我们需要对这些动作随机选择(乱序)
            action = np.random.choice(state_action[state_action == np.max(state_action)].index)
        else:
            action = np.random.choice(self.actions)  # 随机选择一个动作
        return action

    def learn(self, s, a, r, s_, a_):
        self.check_state_exit(s_)  # 查看状态s_是否存在,s_是在选择动作之后与环境交互获得的下一状态
        q_predict = self.q_table.loc[s, a]  # 当前状态s和动作a对应的Q值
        if s_ != 'terminal':  # 若下一步不是终态
            q_target = r + self.gamma * self.q_table.loc[s_, a_]  # 下一动作已经采样得到,直接使用s'与a'的Q值即可
        else:
            q_target = r;  # 否则直接为立即回报
        self.q_table.loc[s, a] += self.lr * (q_target - q_predict)  # 更新Q(s,a)
环境代码 maze_env.py
import numpy as np
import time
import sys
if sys.version_info.major == 2:
    import Tkinter as tk
else:
    import tkinter as tk


UNIT = 40   # pixels
MAZE_H = 4  # grid height
MAZE_W = 4  # grid width


class Maze(tk.Tk, object):
    def __init__(self):
        super(Maze, self).__init__()
        self.action_space = ['u', 'd', 'l', 'r']
        self.n_actions = len(self.action_space)
        self.title('maze')
        self.geometry('{0}x{1}'.format(MAZE_H * UNIT, MAZE_H * UNIT))
        self._build_maze()

    def _build_maze(self):
        self.canvas = tk.Canvas(self, bg='white',
                           height=MAZE_H * UNIT,
                           width=MAZE_W * UNIT)

        # create grids
        for c in range(0, MAZE_W * UNIT, UNIT):
            x0, y0, x1, y1 = c, 0, c, MAZE_H * UNIT
            self.canvas.create_line(x0, y0, x1, y1)
        for r in range(0, MAZE_H * UNIT, UNIT):
            x0, y0, x1, y1 = 0, r, MAZE_W * UNIT, r
            self.canvas.create_line(x0, y0, x1, y1)

        # create origin
        origin = np.array([20, 20])

        # hell
        hell1_center = origin + np.array([UNIT * 2, UNIT])
        self.hell1 = self.canvas.create_rectangle(
            hell1_center[0] - 15, hell1_center[1] - 15,
            hell1_center[0] + 15, hell1_center[1] + 15,
            fill='black')
        # hell
        hell2_center = origin + np.array([UNIT, UNIT * 2])
        self.hell2 = self.canvas.create_rectangle(
            hell2_center[0] - 15, hell2_center[1] - 15,
            hell2_center[0] + 15, hell2_center[1] + 15,
            fill='black')

        # create oval
        oval_center = origin + UNIT * 2
        self.oval = self.canvas.create_oval(
            oval_center[0] - 15, oval_center[1] - 15,
            oval_center[0] + 15, oval_center[1] + 15,
            fill='yellow')

        # create red rect
        self.rect = self.canvas.create_rectangle(
            origin[0] - 15, origin[1] - 15,
            origin[0] + 15, origin[1] + 15,
            fill='red')

        # pack all
        self.canvas.pack()

    def reset(self):
        self.update()
        time.sleep(0.5)
        self.canvas.delete(self.rect)
        origin = np.array([20, 20])
        self.rect = self.canvas.create_rectangle(
            origin[0] - 15, origin[1] - 15,
            origin[0] + 15, origin[1] + 15,
            fill='red')
        # return observation
        return self.canvas.coords(self.rect)

    def step(self, action):
        s = self.canvas.coords(self.rect)
        base_action = np.array([0, 0])
        if action == 0:   # up
            if s[1] > UNIT:
                base_action[1] -= UNIT
        elif action == 1:   # down
            if s[1] < (MAZE_H - 1) * UNIT:
                base_action[1] += UNIT
        elif action == 2:   # right
            if s[0] < (MAZE_W - 1) * UNIT:
                base_action[0] += UNIT
        elif action == 3:   # left
            if s[0] > UNIT:
                base_action[0] -= UNIT

        self.canvas.move(self.rect, base_action[0], base_action[1])  # move agent

        s_ = self.canvas.coords(self.rect)  # next state

        # reward function
        if s_ == self.canvas.coords(self.oval):
            reward = 1
            done = True
            s_ = 'terminal'
        elif s_ in [self.canvas.coords(self.hell1), self.canvas.coords(self.hell2)]:
            reward = -1
            done = True
            s_ = 'terminal'
        else:
            reward = 0
            done = False

        return s_, reward, done

    def render(self):
        time.sleep(0.1)
        self.update()


def update():
    for t in range(10):
        s = env.reset()
        while True:
            env.render()
            a = 1
            s, r, done = env.step(a)
            if done:
                break

if __name__ == '__main__':
    env = Maze()
    env.after(100, update)
    env.mainloop()

上一篇:【莫烦强化学习】视频笔记(三)1.什么是SARSA?
下一篇:【莫烦强化学习】视频笔记(三)3.SARSA(lambda)

  • 0
    点赞
  • 8
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
基于价值的强化学习问题可以使用以下算法进行解决: A. PPO算法 B. SARSA算法 C. DQN算法 D. 策略梯度算法 这些算法都是用于解决强化学习问题的,但是它们在解决问题的方式和原理上有所不同。以下是对每个算法的简要介绍: A. PPO算法(Proximal Policy Optimization)是一种基于策略梯度的算法,它通过优化策略函数来提高智能体的性能。PPO算法通过使用一种称为“重要性采样”的技术来更新策略函数,以平衡探索和利用的权衡。 B. SARSA算法(State-Action-Reward-State-Action)是一种基于值函数的算法,它通过估计每个状态-动作对的值来指导智能体的决策。SARSA算法使用一种称为“时序差分学习”的技术来更新值函数,以逐步改进智能体的策略。 C. DQN算法(Deep Q-Network)是一种基于值函数的算法,它使用深度神经网络来估计状态-动作对的值函数。DQN算法通过使用一种称为“经验回放”的技术来训练神经网络,并使用一种称为“ε-贪婪策略”的技术来指导智能体的决策。 D. 策略梯度算法是一类基于策略梯度的算法,它通过直接优化策略函数来提高智能体的性能。策略梯度算法使用一种称为“策略梯度定理”的技术来更新策略函数,以最大化期望回报。 综上所述,以上四种算法都可以用于解决基于价值的强化学习问题,但它们在解决问题的方式和原理上有所不同。
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值