Sarsa算法讲解及实现

Sarsa算法讲解及实现

1. Q表格

我们使用表格来存储每一个状态 state, 和在这个 state 每个行为 action 所拥有的 Q 值。

Q即为Q(s,a)就是在某一时刻的 s 状态下(s∈S),采取动作a (a∈A)动作能够获得收益的期望,环境会根据agent的动作反馈相应的回报reward r,所以算法的主要思想就是将State与Action构建成一张Q-table来存储Q值,然后根据Q值来选取能够获得最大的收益的动作。

例子:

Q-Tablea1a2
s1q(s1,a1)q(s1,a2)
s2q(s2,a1)q(s2,a2)
s3q(s3,a1)q(s3,a2)

2. Sarsa算法讲解

在强化学习中,Sarsa是一种对Q表格进行更新的算法,由于在强化学习环境最开始的时候,也可以认为是游戏刚开始的时候,Q表格是随机初始化的,所以需要在智能体不断与环境进行交互的时候不断地更新Q表格。

Sarsa表示的是State-Action-Reward-State-Action,是一个学习马尔可夫决策过程策略的算法,通常应用于机器学习和强化学习学习领域中。

State-Action-Reward-State-Action:这个名称清楚地反应了其学习更新函数依赖的5个值:分别是当前状态S1,当前状态选中的动作A1,获得的奖励Reward,S1状态下执行A1后取得的状态S2及S2状态下将会执行的动作A2。我们取这5个值的首字母串起来可以得出一个词SARSA。

Sarsa算法的更新公式:

Q ( s t , a t ) ← Q ( s t , a t ) + α [ r t + γ Q ( s t + 1 , a t + 1 ) − Q ( s t , a t ) ] Q(s_{\boldsymbol{t}},a_{t})\leftarrow Q(s_{\boldsymbol{t}},a_{t})+\alpha [r_{\boldsymbol{t}}+\gamma Q(s_{t+1},a_{\boldsymbol{t}+1})-Q(s_{\boldsymbol{t}},a_{t})] Q(st,at)Q(st,at)+α[rt+γQ(st+1,at+1)Q(st,at)]

Sarsa算法伪代码:
在这里插入图片描述

算法中各个参数的意义:

    1. alpha是学习率, 来决定这次的误差有多少是要被学习的, alpha是一个小于1 的数.
    1. gamma 是对未来 reward 的衰减值. 我们可以这样想象.
    1. Q表示的是Q表格.
    1. Epsilon greedy 是用在决策上的一种策略, 比如 epsilon = 0.9 时, 就说明有90% 的情况我会按照 Q 表的最优值选择行为, 10% 的时间使用随机选行为. 【这也是结合了强化学习中探索和利用的概念】

3. 代码


# agent.py

import numpy as np


class SarsaAgent(object):
    def __init__(self,
                 obs_n,
                 act_n,
                 learning_rate=0.01,
                 gamma=0.9,
                 e_greedy=0.1):
        self.obs_n = obs_n  # 状态维度
        self.act_n = act_n  # 动作维度
        self.learning_rate = learning_rate  # 学习率
        self.gamma = gamma  # 奖励衰减率
        self.e_greedy = e_greedy  # 按一定概率随机选动作
        self.Q = np.zeros((obs_n, act_n))  # Q表格 todo:嵌套一层有什么作用?

    def sample(self, obs):
        if np.random.sample() < (1 - self.e_greedy):  # 强化概念 #根据table的Q值选动作
            return self.predict(obs)
        else:
            # 随机选择一个
            return np.random.choice(self.act_n)

    def predict(self, obs):
        # 进行预测,直接选择Q值最高的那个动作

        # 拉出该状态的那一行动作
        Q_list = self.Q[obs]

        maxQ = np.max(Q_list)
        action_list = np.where(Q_list == maxQ)[0]

        # 防止有多个最大值,所以随机选择一个
        return np.random.choice(action_list)

    def learn(self, obs, act, reward, obs_next, act_next, done):
        """
            obs: 交互前的obs, s_t
            action: 本次交互选择的action, a_t
            reward: 本次动作获得的奖励r
            next_obs: 本次交互后的obs, s_t+1
            next_action: 根据当前Q表格, 针对next_obs会选择的动作, a_t+1
            done: episode是否结束
        """
        # s a r s a

        if done:
            target = reward
        else:
            target = reward + self.gamma * self.Q[obs_next][act_next]

        self.Q[obs][act] += self.learning_rate * (target - self.Q[obs][act])

    def save(self):
        npy_file = './q_table.npy'
        np.save(npy_file, self.Q)
        print(npy_file + ' saved.')

    def restore(self, npy_file='./q_table.npy'):
        self.Q = np.load(npy_file)
        print(npy_file + ' loaded.')

# train.py

import gym
from agent import SarsaAgent
import time

"""
    相当于是跑一个回合呗
"""


def run_episode(env, agent, render=False):
    total_steps = 0
    total_reward = 0

    obs = env.reset()
    action = agent.sample(obs)

    while True:
        # 采取动作,获取环境的反馈值
        next_obs, reward, done, _ = env.step(action)
        next_action = agent.sample(next_obs)

        # 训练sarsa算法
        agent.learn(obs, action, reward, next_obs, next_action, done)

        action = next_action
        obs = next_obs

        total_reward += reward
        total_steps += 1

        if render:
            env.render()  # 渲染新的一帧图形
        if done:
            break

    return total_reward, total_steps


"""
    用于在训练好Q表格后对其进行测试
"""


def test_episode(env, agent):
    total_reward = 0
    obs = env.reset()
    while True:
        action = agent.predict(obs)
        next_obs, reward, done, _ = env.step(action)
        total_reward += reward

        obs = next_obs
        time.sleep(0.5)
        env.render()
        if done:
            print('test reward = %.1f' % total_reward)
            break


def main():
    env = gym.make("CliffWalking-v0")  # 0 up, 1 right, 2 down, 3 left

    agent = SarsaAgent(obs_n=env.observation_space.n,
                       act_n=env.action_space.n,
                       learning_rate=0.1,
                       gamma=0.9,
                       e_greedy=0.1)

    is_render = False

    for episode in range(500):
        ep_reward, ep_steps = run_episode(env, agent, is_render)
        # 每隔20个episode渲染一下看看效果
        if episode % 20 == 0:
            print('it is in ' + str(episode) + 'round')
            is_render = True
        else:
            is_render = False
        # 训练结束,查看算法效果
    test_episode(env, agent)


if __name__ == '__main__':
    main()

  • 3
    点赞
  • 18
    收藏
    觉得还不错? 一键收藏
  • 2
    评论
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值