时序差分更新算法和回合更新算法一样都是利用经验数据进行学习,其区别在于时序差分更新不必等到回合结束,可以用现有的价值估计值来更新。因此时序差分更新既可用于回合制任务,也可用于连续性任务。
同策时序差分更新
从给定策略的情况下动作价值函数的定义出发,我们可以得到:
单步时序差分只需要采样一步,用来估计回报样本的值,这里表示有偏回报样本,与回合更新中由奖励计算得到的无偏回报样本相区别。
基于以上分析,我们可以定义单步时序差分目标为:
这里的U的上标(q)表示是对动作价值定义的,下标t:t+1表示用的估计值来估计。如果是终止状态,默认有。如果把单步扩展到多步,则n步时序差分目标定义为:
类似于回合更新策略,单步时序更新可以表示为:
SARSA算法
最优策略的求解算法如下:
1. (初始化)任意值,。如果有终止状态,令
2.(时序差分更新)对每个回合执行以下操作
2.1(初始化状态动作对)选择状态S,再用策略确定动作A
2.2 如果回合未结束(如未达到最大步数,S不是终止状态),执行以下操作:
2.2.1(采样)执行动作A,观测得到奖励R和新状态
2.2.2 用动作价值确定的策略决定动作(如使用贪心策略)
2.2.3(计算回报的估计值)
2.2.4(更新价值)更新以减小(如)
2.2.5
期望SARSA算法
期望SARSA算法与SARSA算法的不同之处在于使用基于状态价值而不是动作价值来估计,即:
这个算法虽然运算量比SARSA要大,但是可以减小SARSA算法中出现的个别不恰当决策。其最优策略的求解算法如下:
1. (初始化)任意值,。如果有终止状态,令。用动作价值确定策略(如使用贪心策略)
2.(时序差分更新)对每个回合执行以下操作
2.1(初始化状态动作对)选择状态S
2.2 如果回合未结束(如未达到最大步数,S不是终止状态),执行以下操作:
2.2.1 用动作价值确定的策略决定动作A(如使用贪心策略)
2.2.2(采样)执行动作A,观测得到奖励R和新状态
2.2.3(用期望计算回报的估计值)
2.2.4(更新价值)更新以减小(如)
2.2.5
异策时序差分更新
异策时序差分更新是比同策时序差分更新更加流行的算法,特别是Q学习算法。
Q学习算法
Q学习算法的思路是,在根据估计时,与其使用或,还不如使用根据改进后的策略来更新,因为可以更接近最优价值。因此Q学习的更新不是基于当前的策略,而是基于另一个并不一定要使用的确定性策略来更新动作价值,从这个意义来看,Q学习是一个异策算法。求解最优策略的算法如下:
1. (初始化)任意值,。如果有终止状态,令
2.(时序差分更新)对每个回合执行以下操作
2.1(初始化状态动作对)选择状态S
2.2 如果回合未结束(如未达到最大步数,S不是终止状态),执行以下操作:
2.2.1 用动作价值确定的策略决定动作A(如使用贪心策略)
2.2.2(采样)执行动作A,观测得到奖励R和新状态
2.2.3(用改进后的策略计算回报的估计值)
2.2.4(更新价值)更新以减小(如)
2.2.5
双重Q学习
Q学习的一个问题是用来更新动作价值,会带来最大化偏差,使得估计的动作价值过大。我的理解是相当于跳到局部最优的场景。为此可以引入双重Q学习来消除偏差。双重Q学习使用两个独立的动作价值估计和,用或来代替Q学习中的。由于和是两个相互独立的估计,因此可以消除偏差。以下是双重Q学习的最优策略求解算法:
1. (初始化)令任意值,。如果有终止状态,令
2.(时序差分更新)对每个回合执行以下操作
2.1(初始化状态动作对)选择状态S
2.2 如果回合未结束(如未达到最大步数,S不是终止状态),执行以下操作:
2.2.1 用动作价值确定的策略决定动作A(如使用贪心策略)
2.2.2(采样)执行动作A,观测得到奖励R和新状态
2.2.3(随机选择更新或)以等概率选择或中的一个动作价值函数作为更新对象,记选择的是
2.2.4(用改进后的策略计算回报的估计值)
2.2.5(更新价值)更新以减小(如)
2.2.6
出租车调度案例
下面以Gym库中的出租车调度为例,用时序更新的方法来找到最优策略。
在一个5*5的地图中,有四个出租车停靠点,乘客会随机出现在其中一个停靠点,并想在任意一个停靠点下车。出租车会随机出现在25个位置中的其中一个。出租车需要移动自己的位置来接乘客上车,并把乘客运到乘客想下车的位置。出租车每次只能在地图范围内上下左右移动一格,有竖线阻挡的地方不能移动。每次试图移动的奖励是-1,不合理的邀请乘客上车或让乘客下车的奖励为-10。以下代码将初始化环境并打印地图:
import gym
env = gym.make('Taxi-v3')
state = env.reset()
taxirow, taxicol, passloc, destidx = env.unwrapped.decode(state)
print(taxirow, taxicol, passloc, destidx)
print(env.unwrapped.locs[passloc])
print(env.unwrapped.locs[destidx])
env.render()
输出结果如下:
1 0 1 3 (0, 4) (4, 3) +---------+ |R: | : :G| | : | : : | | : : : : | | | : | : | |Y| : |B: | +---------+
从以上结果可以看到,出租车位置的行为1,列为0,对应地图上黄色的长方形。4个位置R,G,Y,B分别对应0-3。乘客当前位置为G,对应为1。乘客要下车的停靠点为B,对应为3。
这个环境中观察到的状态是一个[0,500)的整数值,其中出租车的位置的行数是0-4,列数是0-4。乘客的位置是0-4,其中4表示乘客在车上。目的地的取值是0-3。因此总的状态数为5*5*5*4=500
动作取值是0-5,其中0-3表示上下左右,4表示乘客上车,5表示乘客下车。
SARSA算法求最优策略
代码如下:
import gym
import numpy as np
import matplotlib.pyplot as plt
env = gym.make('Taxi-v3')
class SARSAAgent:
def __init__(self, env, gamma=0.9, learning_rate=0.1, epsilon=0.01):
self.gamma = gamma
self.learning_rate = learning_rate
self.epsilon = epsilon
self.action_n = env.action_space.n
self.q = np.zeros((env.observation_space.n, env.action_space.n))
def decide(self, state):
if np.random.uniform()>self.epsilon:
action = self.q[state].argmax()
else:
action = np.random.randint(self.action_n)
return action
def learn(self, state, action, reward, next_state, done, next_action):
u = reward + self.gamma*self.q[next_state, next_action]*(1.-done)
td_error = u - self.q[state, action]
self.q[state, action] += self.learning_rate*td_error
agent = SARSAAgent(env)
def play_sarsa(env, agent, train=False, render=False):
episode_reward = 0
observation = env.reset()
action = agent.decide(observation)
while True:
if render:
env.render()
next_observation, reward, done, _ = env.step(action)
episode_reward += reward
next_action = agent.decide(next_observation)
if train:
agent.learn(observation, action, reward, next_observation, done, next_action)
if done:
break
observation, action = next_observation, next_action
return episode_reward
episodes = 5000
episode_rewards = []
for episode in range(episodes):
episode_reward = play_sarsa(env, agent, True)
episode_rewards.append(episode_reward)
plt.plot(episode_rewards)
训练完成之后,我们可以看看效果如何,执行代码
play_sarsa(env, agent, False, True)
输出结果如下:
+---------+ |R: | : :G| | : | : : | | : : : : | | | : | : | |Y| : |B: | +---------+ +---------+ |R: | : :G| | : | : : | | : : : : | | | : | : | |Y| : |B: | +---------+ (South) +---------+ |R: | : :G| | : | : : | | : : : : | | | : | : | |Y| : |B: | +---------+ (South) +---------+ |R: | : :G| | : | : : | | : : : : | | | : | : | |Y| : |B: | +---------+ (West) +---------+ |R: | : :G| | : | : : | | : : : : | | | : | : | |Y| : |B: | +---------+ (West) +---------+ |R: | : :G| | : | : : | | : : : : | | | : | : | |Y| : |B: | +---------+ (West) +---------+ |R: | : :G| | : | : : | | : : : : | | | : | : | |Y| : |B: | +---------+ (South) +---------+ |R: | : :G| | : | : : | | : : : : | | | : | : | |Y| : |B: | +---------+ (South) +---------+ |R: | : :G| | : | : : | | : : : : | | | : | : | |Y| : |B: | +---------+ (Pickup) +---------+ |R: | : :G| | : | : : | | : : : : | |_| : | : | |Y| : |B: | +---------+ (North) +---------+ |R: | : :G| | : | : : | |_: : : : | | | : | : | |Y| : |B: | +---------+ (North) +---------+ |R: | : :G| | : | : : | | :_: : : | | | : | : | |Y| : |B: | +---------+ (East) +---------+ |R: | : :G| | : | : : | | : :_: : | | | : | : | |Y| : |B: | +---------+ (East) +---------+ |R: | : :G| | : | : : | | : : :_: | | | : | : | |Y| : |B: | +---------+ (East) +---------+ |R: | : :G| | : | : : | | : : : : | | | : |_: | |Y| : |B: | +---------+ (South) +---------+ |R: | : :G| | : | : : | | : : : : | | | : | : | |Y| : |B: | +---------+ (South)