SARSA 算法

SARSA,它是一种 TD 算法,SARSA 的目的是学习动作价值函数 Q π ( s , a ) Q_π(s, a) Qπ(s,a)
Q π Q_π Qπ 通常被用于评价策略的好坏,而非用于控制智能体。 Q π Q_π Qπ 常与策略函数 π 结合使用,被称作 actor-critic(演员—评委)方法。策略函数 π 控制智能体,因此被看做“演员”;而 Q π Q_π Qπ 评价 π 的表现,帮助改进 π,因此 Q π Q_π Qπ 被看做“评委”。Actor-critic 通常用 SARSA 训练“评委” Q π Q_π Qπ

SARSA 算法的推导

SARSA 算法由下面的贝尔曼方程推导出:
在这里插入图片描述

用一个神经网络 q(s, a; w) 来近似 Q π ( s , a ) Q_π(s, a) Qπ(s,a)
给定当前状态 s t s_t st,智能体执行动作 a t a_t at,环境会给出奖励 r t r_t rt 和新的状态 s t + 1 s_{t+1} st+1。然后基于 s t + 1 s_{t+1} st+1 做随机抽样,得到新的动作 a t + 1 ~ ∼ π ( ⋅ ∣ s t + 1 ) \tilde{a_{t+1}} \sim π(· | s_{t+1}) at+1~π(st+1)。定义 TD 目标:
y t ^ ≜ r t + γ ⋅ q ( s t + 1 , a t + 1 ^ ; w ) \hat{y_t} ≜ r_t + \gamma \cdot q(s_{t+1}, \hat{a_{t+1}}; w) yt^rt+γq(st+1,at+1^;w)
我们鼓励 q ( s t , a t ; w ) q(s_t, a_t; w) q(st,at;w) 接近 TD 目标 y t ^ \hat{y_t} yt^,所以定义损失函数:
L ( w ) ≜ 1 2 [ q ( s t , a t ; w ) − y t ^ ] 2 L(w) ≜ \frac{1}{2} [q(s_t, a_t;w) - \hat{y_t}]^2 L(w)21[q(st,at;w)yt^]2
损失函数的变量是 w,而 y t ^ \hat{y_t} yt^被视为常数。(尽管 y t ^ \hat{y_t} yt^也依赖于参数 w,但这一点被忽略掉。)设 q t ^ \hat{q_t} qt^= q ( s t , a t ; w ) q(s_t, a_t; w) q(st,at;w)。损失函数关于 w 的梯度是:
∇ w L ( w ) = ( q t ^ − y t ^ ) ⏟ T D 误差 δ t ⋅ ∇ w q ( s t , a t ; w ) \nabla_{w} L(w) = \underbrace{(\hat{q_t} - \hat{y_t})}_{TD误差\delta_t} \cdot \nabla_{w}q(s_t, a_t; w) wL(w)=TD误差δt (qt^yt^)wq(st,at;w)
做一次梯度下降更新 w:
w ← w − α ⋅ δ t ⋅ ∇ w q ( s t , a t ; w ) w \leftarrow w - \alpha \cdot \delta_t \cdot \nabla_wq(s_t, a_t; w) wwαδtwq(st,at;w)

训练流程

设当前价值网络的参数为 w n o w w_{now} wnow,当前策略为 π n o w π_{now} πnow。每一轮训练用五元组 ( s t , a t , r t , s t + 1 , a t + 1 ~ ) (s_t, a_t, r_t, s_{t+1}, \tilde{a_{t+1}}) (st,at,rt,st+1,at+1~) 对价值网络参数做一次更新。

  1. 观测到当前状态 s t s_t st,根据当前策略做抽样: a t ∼ π n o w ( ⋅ ∣ s t ) a_t ∼ π_{now}(· | s_t) atπnow(st)
  2. 用价值网络计算 ( s t , a t ) (s_t, a_t) (st,at) 的价值:
    q t ^ = q ( s t , a t ; w n o w ) \hat{q_t} = q(s_t, a_t; w_{now}) qt^=q(st,at;wnow)
  3. 智能体执行动作 a t a_t at 之后,观测到奖励 r t r_t rt 和新的状态 s t + 1 s_{t+1} st+1
  4. 根据当前策略做抽样: a t + 1 ^ ∼ π n o w ( ⋅ ∣ s t + 1 ) \hat{a_{t+1}} ∼ π_{now}(· |s_{t+1}) at+1^πnow(st+1)。注意, a t + 1 ^ \hat{a_{t+1}} at+1^ 只是假想的动作,智能体不予执行。
  5. 用价值网络计算 ( s t + 1 , a t + 1 ^ ) (s_{t+1}, \hat{a_{t+1}}) (st+1,at+1^) 的价值:
    q t + 1 ^ = q ( s t + 1 , a t + 1 ^ ; w n o w ) \hat{q_{t+1}} = q(s_{t+1}, \hat{a_{t+1}};w_{now}) qt+1^=q(st+1,at+1^;wnow)
  6. 计算 TD 目标和 TD 误差:
    y t ^ = r t + γ ⋅ q t + 1 ^ , δ t = q t ^ − y t ^ \hat{y_t} = r_t + \gamma \cdot \hat{q_{t+1}}, \delta_t = \hat{q_t} - \hat{y_t} yt^=rt+γqt+1^,δt=qt^yt^
  7. 对价值网络 q 做反向传播,计算 q 关于 w 的梯度:
    ∇ w q ( s t , a t ; w n o w ) \nabla_{w}q(s_t, a_t; w_{now}) wq(st,at;wnow)
  8. 更新价值网络参数:
    w n e w ← w n o w − α ⋅ δ t ⋅ ∇ w q ( s t , a t ; w n o w ) w_{new} \leftarrow w_{now} - \alpha \cdot \delta_t \cdot \nabla_wq(s_t, a_t; w_{now}) wnewwnowαδtwq(st,at;wnow)
  9. 用某种算法更新策略函数。该算法与 SARSA 算法无关

参考 https://github.com/wangshusen/DRL

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值