强化学习7--TD and SARSA

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档


前言

个人笔记,恳请纠错,请勿转载


TD learning

一、Temporal-Difference:what?

首先我想给出我认为TD和动态规划算法DP最核心的区别:policy evalution的公式不同
DP:在这里插入图片描述
这段代码的含义是将得到的states_and_rewards翻转,从最后的state_value开始计算, 具体的计算公式为G = r + GAMMA * G, 就是常见的求state_value的公式。

TD:
在这里插入图片描述
这是TD算法的数学公式,它利用梯度下降法来进行state_value的更新,其中r_t +1 + γv_t(s_t + 1)实际上就是在state s_t到达s_t+1获得的reward,而不是整个探索过程的return。

形象一点讲这个事情就是下图所示:
在这里插入图片描述
1.MC 算法使用的是actual G_t(固定初始policy),而TD只是应用了单步的r_t +1 + γv_t(s_t + 1),DP算法则是不固定初始policy的任意探索后返回的state value。

二、Convergence:why?

在这里插入图片描述
给出结论:v_t最终无论如何都会收敛到v_π,具体证明看书。

三.Code

代码最主要的两部分,一部分是play_game(),游戏的过程。第二部分是迭代:

def play_game(grid, policy):
  # returns a list of states and corresponding rewards (not returns as in MC)
  # start at the designated start state
  s = (2, 0)
  grid.set_state(s)
  states_and_rewards = [(s, 0)]  # list of tuples of (state, reward)
  while not grid.game_over():
    a = policy[s]
    a = random_action(a)
    r = grid.move(a)
    s = grid.current_state()
    states_and_rewards.append((s, r))
  return states_and_rewards

这是第一部分,网格世界游戏过程中我们分别计算每一步的reward,并最终以列表的形式进行返回。和MC最大的区别在于没有计算return,这也是TD的核心。

for it in range(1000):
  # generate an episode using pi
  states_and_rewards = play_game(grid, policy)
  # the first (s, r) tuple is the state we start in and 0
  # (since we don't get a reward) for simply starting the game
  # the last (s, r) tuple is the terminal state and the final reward
  # the value for the terminal state is by definition 0, so we don't
  # care about updating it.
  for t in range(len(states_and_rewards) - 1):
    s, _ = states_and_rewards[t]
    s2, r = states_and_rewards[t+1]
    # we will update V(s) AS we experience the episode
    V[s] = V[s] + ALPHA*(r + GAMMA*V[s2] - V[s])

迭代的过程中,对每一次游戏过程的每一个s_t都计算它的reward,并进行state value更新,即使更新使用的不是真正的return,但其state value 最终也会收敛。循环1000次。

SARSA

一、Sarsa:what?

个人学习算法,我认为最好的方法就是与已学算法做比较。
在这里插入图片描述

它与基础TD算法的区别就是用动作价值函数来代替状态价值函数。
SARSA的本质目的是为了解决这样一个贝尔曼公式:在这里插入图片描述
这个贝尔曼公式也使用了action value来代替state value。
具体证明如下:
在这里插入图片描述
使用的时候只要知道它是一个贝尔曼公式就行。

SARSA和基础的TD一样也是收敛的:
在这里插入图片描述

二、Sarsa:optimal policy code

伪代码如下:
在这里插入图片描述
code:

grid = negative_grid(step_cost = -0.1)
Q = {}
states = grid.all_states()
for s in states:
  Q[s] = {}
  for a in ALL_POSSIBLE_ACTIONS:
    Q[s][a] = 0

update_counts = {}
update_counts_sa = {}
for s in states:
  update_counts_sa[s] = {}
  for a in ALL_POSSIBLE_ACTIONS:
    update_counts_sa[s][a] = 1.0

t = 1.0
deltas = []
for it in range(10000):
  if it % 100 == 0:
    t += 1e-2
  if it % 2000 == 0:
    print("iteration:", it)

  s = (2, 0)
  grid.set_state(s)
  a = max_dict(Q[s])[0]
  a = random_action(a, eps = 0.5/t)
  biggest_change = 0
  while not grid.game_over():
    r = grid.move(a)
    s2 = grid.current_state()
    a2 = max_dict(Q[s2])[0]
    a2 = random_action(a2, eps = 0.5/t)

    alpha = ALPHA / update_counts_sa[s][a]
    update_counts_sa[s][a] += 0.005
    old_qsa = Q[s][a]
    Q[s][a] = Q[s][a] + alpha*(r+GAMMA*Q[s2][a2] - Q[s][a])
    biggest_change = max(biggest_change, np.abs(old_qsa - Q[s][a]))
    update_counts[s] = update_counts.get(s,0) + 1
    s = s2
    a = a2
  deltas.append(biggest_change)

policy = {}
V = {}
for s in grid.actions.keys():
  a, max_q = max_dict(Q[s])
  policy[s] = a
  V[s] = max_q
print("update counts:")
total = np.sum(list(update_counts.values()))
for k, v in update_counts.items():
  update_counts[k] = float(v) / total
print_values(update_counts, grid)

print("final values:")
print_values(V, grid)
print("final policy:")
print_policy(policy, grid)
  

tip:上面的代码是先迭代,最后优化policy。
result:
在这里插入图片描述

N-step Sarsa

在这里插入图片描述
假设1<n<m,n对应为单个episdoe的第n步,MC算法就是n取无穷的n-step Sarsa,Sarsa就是n取1的n-step Sarsa。

总结

摘录于:https://github.com/rookiexxj/Reinforcement_learning_tutorial_with_demo/blob/master/SARSA_demo.ipynb
https://www.bilibili.com/video/BV1sd4y167NS/?spm_id_from=333.337.search-card.all.click

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

rookiexxj01

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值