SARSA
SARSA算法是TD Learning中的一个重要应用,TD算法是根据一个直接后继状态节点的单次样本转移来更新的,它没有使用完整的一幕,并且它采用了自举法,因为它用了后继状态的估计值来更新当前状态的估计值。
我们导入必要的包并熟悉以下“出租车调度”的环境。
import gym
import collections
import itertools
import matplotlib.pyplot as plt
import matplotlib
matplotlib.style.use('ggplot')
import numpy as np
import pandas as pd
import time
from IPython.display import clear_output
env = gym.make('Taxi-v3')
print(env.action_space)
print(env.observation_space)
state = env.reset()
env.render()
x = env.decode(state)
for i in x:
print(i)
该环境中动作有六个分别是:上、下、左、右、请乘客上车、请乘客下车。黄色方块表示没有乘客的汽车,如果有乘客则是绿色。汽车只能在虚线部分改变方向。R、G、Y、B是四个站点,代号分别是0、1、2、3蓝色代表乘客所在位置,红色代表乘客目的地。每个状态由一个元组:(taxirow, taxicol, passloc, destidx)表示,前两个表示了出租车的坐标如图所示在(3,1)处,第三、第四个元素分别是乘客所在位置和目的地的代号。因此共有 ( 5 × 5 ) × 5 × 4 = 500 (5 \times 5)\times 5 \times 4=500 (5×5)×5×4=500 个状态。每试图移动一次的收益是-1,错误地让乘客下车或上车收益是-10,顺利地完成一次任务收益是20,直到完成任务或者200步后还没能完成任务一幕就结束。
接下来定义一个agent类:
class SARSAagent:
def __init__(self, env, gamma=0.9, learning_rate=0.1, epsilon=0.1):
self.gamma = gamma
self.learning_rate = learning_rate
self.epsilon = epsilon
self.action_n = env.action_space.n
self.q_table = np.zeros((env.observation_space.n, env.action_space.n))
def use_epsilon_greedy_policy(self, state):
if np.random.uniform() > self.epsilon:
action = np.argmax(self.q_table[state])
else:
action = np.random.randint(self.action_n)
return action
def learn(self, state, action, reward, next_state, next_action, done):
td_target = reward + self.gamma * self.q_table[next_state][next_action] * (1. - done)
td_error = td_target - self.q_table[state][action]
self.q_table[state][action] += self.learning_rate * td_error
这里的
ϵ
−
g
r
e
e
d
y
p
o
l
i
c
y
\epsilon-greedy \ policy
ϵ−greedy policy 和MC算法实战中的定义有所区别,MC中是根据
ϵ
−
g
r
e
e
d
y
p
o
l
i
c
y
\epsilon-greedy \ policy
ϵ−greedy policy 本质上是一个分段函数来定义的,这里则是直接从定义出发,即在(0,1)按均匀分布随机生成一个数,如果它比
ϵ
\epsilon
ϵ 大则选择最优动作(相当于是 $1-\epsilon $ 的概率)否则就随便选一个动作。learn
方法则是完全遵照了SARSA的更新定义。
接下来就结合着SARSA的算法来定义execute_SARSA_one_episode
:
def execute_SARSA_one_episode(env, agnet, render = False):
total_steps, total_rewards = 0.0, 0.0
state = env.reset()
action = agent.use_epsilon_greedy_policy(state)
while True:
if render:
env.render()
clear_output(wait=True)
time.sleep(0.02)
next_state, reward, done, _ = env.step(action)
total_steps += 1.
total_rewards += reward
next_action = agent.use_epsilon_greedy_policy(next_state)
agent.learn(state, action, reward, next_state, next_action, done)
if done:
if render:
clear_output(wait=True)
print('END')
print('total_steps: ', total_steps)
time.sleep(3)
break
else:
state, action = next_state, next_action
return total_steps, total_rewards
每次对agent类的初始化中将Q_table中的每一个值都定义成0,在一幕中对每个“状态—动作”对按照SARSA算法进行一次更新,进行5000幕那么每个"状态—动作"对都将收敛。
这里用到了一个unzip的小技巧: list(zip(*result))
!
agent = SARSAagent(env)
result = [execute_SARSA_one_episode(env, agent) for _ in range(5000)]
unziped_resutl = list(zip(*result))
steps = list(unziped_resutl[0])
rewards = list(unziped_resutl[1])
然后利用pandas中的rolling平滑曲线并将其绘制出来
steps_smoothed = pd.Series(steps).rolling(20, min_periods=20).mean()
plt.figure(figsize=(15, 8))
plt.title('steps of each episode', fontsize=25, color='r')
plt.xticks(fontsize=20)
plt.yticks(fontsize=20)
plt.plot(steps_smoothed, color='b')
plt.savefig('SARSA_steps_of_each_episode.png')
可以看到随着训练的幕逐步增多,agent从一开始走完200步都不能完成任务到最后基本上走二十几步就能完成。
收益曲线也相应的逐幕收敛了。
rewards_smoothed = pd.Series(rewards).rolling(20,20).mean()
plt.figure(figsize=(15, 8))
plt.title('rewards of each episode', fontsize=25, color='r')
plt.xticks(fontsize=20)
plt.yticks(fontsize=20)
plt.plot(rewards_smoothed, color='b')
plt.savefig('SARSA_rewards_of_each_episode.png')
最后打印一下Q_table和policy
pd.DataFrame(agent.q_table)
0 | 1 | 2 | 3 | 4 | 5 | |
---|---|---|---|---|---|---|
0 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 |
1 | -3.985126 | -3.424446 | -4.072099 | -3.732169 | -1.682856 | -8.041540 |
2 | -1.466943 | 0.117121 | 0.119486 | 0.656357 | 4.190408 | -4.835304 |
3 | -3.140358 | -2.997036 | -3.705371 | -2.777566 | -0.355685 | -8.073674 |
4 | -6.745418 | -6.919960 | -6.771730 | -6.735160 | -9.493358 | -10.074422 |
... | ... | ... | ... | ... | ... | ... |
495 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 |
496 | -2.881833 | -2.833192 | -2.878966 | -2.271016 | -3.764418 | -3.745854 |
497 | -1.783886 | -1.318232 | -1.769246 | -0.026576 | -1.909000 | -2.861079 |
498 | -3.281427 | -3.167514 | -3.254752 | -2.458533 | -4.899401 | -4.452290 |
499 | -0.190000 | -0.199000 | 0.757882 | 14.562293 | -2.731714 | -1.000000 |
500 rows × 6 columns
policy = np.eye(agent.action_n)[agent.q_table.argmax(axis=-1)]
policy
array([[1., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 1., 0.],
[0., 0., 0., 0., 1., 0.],
...,
[0., 0., 0., 1., 0., 0.],
[0., 0., 0., 1., 0., 0.],
[0., 0., 0., 1., 0., 0.]])
参考资料
《强化学习原理与Python实现》肖智清