强化学习实战之SARSA算法

SARSA

SARSA算法是TD Learning中的一个重要应用,TD算法是根据一个直接后继状态节点的单次样本转移来更新的,它没有使用完整的一幕,并且它采用了自举法,因为它用了后继状态的估计值来更新当前状态的估计值。

我们导入必要的包并熟悉以下“出租车调度”的环境。

import gym
import collections
import itertools
import matplotlib.pyplot as plt
import matplotlib
matplotlib.style.use('ggplot')
import numpy as np
import pandas as pd
import time
from IPython.display import clear_output
env = gym.make('Taxi-v3')
print(env.action_space)
print(env.observation_space)
state = env.reset()
env.render()
x = env.decode(state)
for i in x:
    print(i)

在这里插入图片描述

该环境中动作有六个分别是:上、下、左、右、请乘客上车、请乘客下车。黄色方块表示没有乘客的汽车,如果有乘客则是绿色。汽车只能在虚线部分改变方向。R、G、Y、B是四个站点,代号分别是0、1、2、3蓝色代表乘客所在位置,红色代表乘客目的地。每个状态由一个元组:(taxirow, taxicol, passloc, destidx)表示,前两个表示了出租车的坐标如图所示在(3,1)处,第三、第四个元素分别是乘客所在位置和目的地的代号。因此共有 ( 5 × 5 ) × 5 × 4 = 500 (5 \times 5)\times 5 \times 4=500 (5×5)×5×4=500 个状态。每试图移动一次的收益是-1,错误地让乘客下车或上车收益是-10,顺利地完成一次任务收益是20,直到完成任务或者200步后还没能完成任务一幕就结束。

接下来定义一个agent类:

class SARSAagent:
    def __init__(self, env, gamma=0.9, learning_rate=0.1, epsilon=0.1):
        self.gamma = gamma
        self.learning_rate = learning_rate
        self.epsilon = epsilon
        self.action_n = env.action_space.n
        self.q_table = np.zeros((env.observation_space.n, env.action_space.n))
        
        
    def use_epsilon_greedy_policy(self, state):
        if np.random.uniform() > self.epsilon:
            action = np.argmax(self.q_table[state])
        else:
            action = np.random.randint(self.action_n)
        return action
    
    
    def learn(self, state, action, reward, next_state, next_action, done):
        td_target = reward + self.gamma * self.q_table[next_state][next_action] * (1. - done)
        td_error = td_target - self.q_table[state][action]
        self.q_table[state][action] += self.learning_rate * td_error

这里的 ϵ − g r e e d y   p o l i c y \epsilon-greedy \ policy ϵgreedy policy 和MC算法实战中的定义有所区别,MC中是根据 ϵ − g r e e d y   p o l i c y \epsilon-greedy \ policy ϵgreedy policy 本质上是一个分段函数来定义的,这里则是直接从定义出发,即在(0,1)按均匀分布随机生成一个数,如果它比 ϵ \epsilon ϵ 大则选择最优动作(相当于是 $1-\epsilon $ 的概率)否则就随便选一个动作。learn方法则是完全遵照了SARSA的更新定义。

接下来就结合着SARSA的算法来定义execute_SARSA_one_episode

def execute_SARSA_one_episode(env, agnet, render = False):
    total_steps, total_rewards = 0.0, 0.0
    state = env.reset()
    action = agent.use_epsilon_greedy_policy(state)
    while True:
        if render:
            env.render()
            clear_output(wait=True)
            time.sleep(0.02)
        next_state, reward, done, _ = env.step(action)
        total_steps += 1.
        total_rewards += reward
        next_action = agent.use_epsilon_greedy_policy(next_state)
        agent.learn(state, action, reward, next_state, next_action, done)
        if done:
            if render:
                clear_output(wait=True)
                print('END')
                print('total_steps: ', total_steps)
                time.sleep(3)
            break
        else:
            state, action = next_state, next_action
    return total_steps, total_rewards

在这里插入图片描述

在这里插入图片描述

每次对agent类的初始化中将Q_table中的每一个值都定义成0,在一幕中对每个“状态—动作”对按照SARSA算法进行一次更新,进行5000幕那么每个"状态—动作"对都将收敛。

这里用到了一个unzip的小技巧: list(zip(*result))!

agent = SARSAagent(env)
result = [execute_SARSA_one_episode(env, agent) for _ in range(5000)]
unziped_resutl = list(zip(*result))
steps = list(unziped_resutl[0])
rewards = list(unziped_resutl[1])

然后利用pandas中的rolling平滑曲线并将其绘制出来

steps_smoothed = pd.Series(steps).rolling(20, min_periods=20).mean()
plt.figure(figsize=(15, 8))
plt.title('steps of each episode', fontsize=25, color='r')
plt.xticks(fontsize=20)
plt.yticks(fontsize=20)
plt.plot(steps_smoothed, color='b')
plt.savefig('SARSA_steps_of_each_episode.png')

在这里插入图片描述

可以看到随着训练的幕逐步增多,agent从一开始走完200步都不能完成任务到最后基本上走二十几步就能完成。

收益曲线也相应的逐幕收敛了。

rewards_smoothed = pd.Series(rewards).rolling(20,20).mean()
plt.figure(figsize=(15, 8))
plt.title('rewards of each episode', fontsize=25, color='r')
plt.xticks(fontsize=20)
plt.yticks(fontsize=20)
plt.plot(rewards_smoothed, color='b')
plt.savefig('SARSA_rewards_of_each_episode.png')

在这里插入图片描述

最后打印一下Q_table和policy

pd.DataFrame(agent.q_table)
012345
00.0000000.0000000.0000000.0000000.0000000.000000
1-3.985126-3.424446-4.072099-3.732169-1.682856-8.041540
2-1.4669430.1171210.1194860.6563574.190408-4.835304
3-3.140358-2.997036-3.705371-2.777566-0.355685-8.073674
4-6.745418-6.919960-6.771730-6.735160-9.493358-10.074422
.....................
4950.0000000.0000000.0000000.0000000.0000000.000000
496-2.881833-2.833192-2.878966-2.271016-3.764418-3.745854
497-1.783886-1.318232-1.769246-0.026576-1.909000-2.861079
498-3.281427-3.167514-3.254752-2.458533-4.899401-4.452290
499-0.190000-0.1990000.75788214.562293-2.731714-1.000000

500 rows × 6 columns

policy = np.eye(agent.action_n)[agent.q_table.argmax(axis=-1)]
policy
array([[1., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 1., 0.],
       [0., 0., 0., 0., 1., 0.],
       ...,
       [0., 0., 0., 1., 0., 0.],
       [0., 0., 0., 1., 0., 0.],
       [0., 0., 0., 1., 0., 0.]])

参考资料

《强化学习原理与Python实现》肖智清

github参考代码

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值