深度强化学习(四)SARSA

深度强化学习(四)SARSA算法

一.SARSA

假设状态空间 S \mathcal{S} S 和动作空间 A \mathcal{A} A 都是有限集, 即集合中元素数量有限。比如, S \mathcal{S} S 中一共有 3 种状态, A \mathcal{A} A 中一共有 4 种动作。那么动作价值函数 Q π ( s , a ) Q_\pi(s, a) Qπ(s,a) 可以表示为一个 3 × 4 3 \times 4 3×4 的表格。该表格与一个策略函数 π ( a ∣ s ) \pi(a \mid s) π(as) 相关联; 如果 π \pi π 发生变化,表格 Q π Q_\pi Qπ 也会发生变化。

我们用表格 q q q 近似 Q π Q_\pi Qπ 。首先初始化 q q q, 可以让它是全零的表格。然后用表格形式的 SARSA 算法更新 q q q,每次更新表格的一个元素。最终 q q q 收敛到 Q π Q_\pi Qπ

SARSA 算法由下面的贝尔曼方程推导出 :
Q π ( s t , a t ) = E S t + 1 , A t + 1 [ R t + γ ⋅ Q π ( S t + 1 , A t + 1 ) ∣ S t = s t , A t = a t ] Q_\pi\left(s_t, a_t\right)=\mathbb{E}_{S_{t+1}, A_{t+1}}\left[R_t+\gamma \cdot Q_\pi\left(S_{t+1}, A_{t+1}\right) \mid S_t=s_t, A_t=a_t\right] Qπ(st,at)=ESt+1,At+1[Rt+γQπ(St+1,At+1)St=st,At=at]

我们对贝尔曼方程左右两边做近似:

  • 方程左边的 Q π ( s t , a t ) Q_\pi\left(s_t, a_t\right) Qπ(st,at) 可以近似成 q ( s t , a t ) 。 q ( s t , a t ) q\left(s_t, a_t\right) 。 q\left(s_t, a_t\right) q(st,at)q(st,at) 是表格在 t t t 时刻对 Q π ( s t , a t ) Q_\pi\left(s_t, a_t\right) Qπ(st,at)做出的估计。
  • 方程右边的期望是关于下一时刻状态 S t + 1 S_{t+1} St+1 和动作 A t + 1 A_{t+1} At+1 求的。给定当前状态 s t s_t st, 智能体执行动作 a t a_t at, 环境会给出奖励 r t r_t rt 和新的状态 s t + 1 s_{t+1} st+1 。然后基于 s t + 1 s_{t+1} st+1 做随机抽样,得到新的动作

a ~ t + 1 ∼ π ( ⋅ ∣ s t + 1 ) . \tilde{a}_{t+1} \sim \pi\left(\cdot \mid s_{t+1}\right) . a~t+1π(st+1).

用观测到的 r t 、 s t + 1 r_t 、 s_{t+1} rtst+1 和计算出的 a ~ t + 1 \tilde{a}_{t+1} a~t+1 对期望做蒙特卡洛近似, 得到:
r t + γ ⋅ Q π ( s t + 1 , a ~ t + 1 ) . r_t+\gamma \cdot Q_\pi\left(s_{t+1}, \tilde{a}_{t+1}\right) . rt+γQπ(st+1,a~t+1).

  • 进一步把公式 (5.1) 中的 Q π Q_\pi Qπ 近似成 q q q, 得到

y ^ t ≜ r t + γ ⋅ q ( s t + 1 , a ~ t + 1 ) . \widehat{y}_t \triangleq r_t+\gamma \cdot q\left(s_{t+1}, \tilde{a}_{t+1}\right) . y trt+γq(st+1,a~t+1).

把它称作 TD 目标。它是表格在 t + 1 t+1 t+1 时刻对 Q π ( s t , a t ) Q_\pi\left(s_t, a_t\right) Qπ(st,at) 做出的估计。
q ( s t , a t ) q\left(s_t, a_t\right) q(st,at) y ^ t \widehat{y}_t y t 都是对动作价值 Q π ( s t , a t ) Q_\pi\left(s_t, a_t\right) Qπ(st,at) 的估计。由于 y ^ t \widehat{y}_t y t 部分基于真实观测到的奖励 r t r_t rt,我们认为 y ^ t \widehat{y}_t y t 是更可靠的估计, 所以鼓励 q ( s t , a t ) q\left(s_t, a_t\right) q(st,at) 趋近 y ^ t \widehat{y}_t y t 。更新表格 ( s t , a t ) \left(s_t, a_t\right) (st,at) 位置上的元素:
q ( s t , a t ) ← ( 1 − α ) ⋅ q ( s t , a t ) + α ⋅ y ^ t . q\left(s_t, a_t\right) \leftarrow(1-\alpha) \cdot q\left(s_t, a_t\right)+\alpha \cdot \widehat{y}_t . q(st,at)(1α)q(st,at)+αy t.

这样可以使得 q ( s t , a t ) q\left(s_t, a_t\right) q(st,at) 更接近 y ^ t \widehat{y}_t y t 。 SARSA 算法用到了这个五元组: ( s t , a t , r t , s t + 1 , a ~ t + 1 ) \left(s_t, a_t, r_t, s_{t+1}, \tilde{a}_{t+1}\right) (st,at,rt,st+1,a~t+1) 。SARSA 算法学到的 q q q 依赖于策略 π \pi π, 这是因为五元组中的 a ~ t + 1 \tilde{a}_{t+1} a~t+1 是根据 π ( ⋅ ∣ s t + 1 ) \pi\left(\cdot \mid s_{t+1}\right) π(st+1) 抽样得到的。

训练流程:设当前表格为 q now  q_{\text {now }} qnow , 当前策略为 π now  \pi_{\text {now }} πnow  每一轮更新表格中的一个元素,把更新之后的表格记作 q new  q_{\text {new }} qnew 

  1. 观测到当前状态 s t s_t st, 根据当前策略做抽样: a t ∼ π now  ( ⋅ ∣ s t ) a_t \sim \pi_{\text {now }}\left(\cdot \mid s_t\right) atπnow (st)
  2. 把表格 q now  q_{\text {now }} qnow  中第 ( s t , a t ) \left(s_t, a_t\right) (st,at) 位置上的元素记作:

q ^ t = q now  ( s t , a t ) . \widehat{q}_t=q_{\text {now }}\left(s_t, a_t\right) . q t=qnow (st,at).

  1. 智能体执行动作 a t a_t at 之后, 观测到奖励 r t r_t rt 和新的状态 s t + 1 s_{t+1} st+1
  2. 根据当前策略做抽样: a ~ t + 1 ∼ π now  ( ⋅ ∣ s t + 1 ) \tilde{a}_{t+1} \sim \pi_{\text {now }}\left(\cdot \mid s_{t+1}\right) a~t+1πnow (st+1) 。注意, a ~ t + 1 \tilde{a}_{t+1} a~t+1 只是假想的动作, 智能体不予执行。
  3. 把表格 q now  q_{\text {now }} qnow  中第 ( s t + 1 , a ~ t + 1 ) \left(s_{t+1}, \tilde{a}_{t+1}\right) (st+1,a~t+1) 位置上的元素记作:

q ^ t + 1 = q now  ( s t + 1 , a ~ t + 1 ) . \widehat{q}_{t+1}=q_{\text {now }}\left(s_{t+1}, \tilde{a}_{t+1}\right) . q t+1=qnow (st+1,a~t+1).

  1. 计算 TD 目标和 TD 误差:

y ^ t = r t + γ ⋅ q ^ t + 1 , δ t = q ^ t − y ^ t . \widehat{y}_t=r_t+\gamma \cdot \widehat{q}_{t+1}, \quad \delta_t=\widehat{q}_t-\widehat{y}_t . y t=rt+γq t+1,δt=q ty t.

  1. 更新表格中 ( s t , a t ) \left(s_t, a_t\right) (st,at) 位置上的元素:

q new  ( s t , a t ) ← q now  ( s t , a t ) − α ⋅ δ t . q_{\text {new }}\left(s_t, a_t\right) \leftarrow q_{\text {now }}\left(s_t, a_t\right)-\alpha \cdot \delta_t . qnew (st,at)qnow (st,at)αδt.

  1. 用某种算法更新策略函数。该算法与 SARSA 算法无关。

二.神经网络形式的SARSA

**价值网络:**如果状态空间 S \mathcal{S} S 是无限集, 那么我们无法用一张表格表示 Q π Q_\pi Qπ, 否则表格的行数是无穷。一种可行的方案是用一个神经网络 q ( s , a ; w ) q(s, a ; \boldsymbol{w}) q(s,a;w) 来近似 Q π ( s , a ) Q_\pi(s, a) Qπ(s,a); 理想情况下,
q ( s , a ; w ) = Q π ( s , a ) , ∀ s ∈ S , a ∈ A q(s, a ; \boldsymbol{w})=Q_\pi(s, a), \quad \forall s \in \mathcal{S}, a \in \mathcal{A} q(s,a;w)=Qπ(s,a),sS,aA
训练流程 : 设当前价值网络的参数为 w n o w \boldsymbol{w}_{\mathrm{now}} wnow, 当前策略为 π n o w \pi_{\mathrm{now}} πnow 每一轮训练用五元组 ( s t , a t , r t , s t + 1 , a ~ t + 1 ) \left(s_t, a_t, r_t, s_{t+1}, \tilde{a}_{t+1}\right) (st,at,rt,st+1,a~t+1) 对价值网络参数做一次更新。

  1. 观测到当前状态 s t s_t st, 根据当前策略做抽样: a t ∼ π now  ( ⋅ ∣ s t ) a_t \sim \pi_{\text {now }}\left(\cdot \mid s_t\right) atπnow (st)
  2. 用价值网络计算 ( s t , a t ) \left(s_t, a_t\right) (st,at) 的价值:

q ^ t = q ( s t , a t ; w now  ) . \widehat{q}_t=q\left(s_t, a_t ; \boldsymbol{w}_{\text {now }}\right) . q t=q(st,at;wnow ).

  1. 智能体执行动作 a t a_t at 之后, 观测到奖励 r t r_t rt 和新的状态 s t + 1 s_{t+1} st+1
  2. 根据当前策略做抽样: a ~ t + 1 ∼ π n o w ( ⋅ ∣ s t + 1 ) \tilde{a}_{t+1} \sim \pi_{\mathrm{now}}\left(\cdot \mid s_{t+1}\right) a~t+1πnow(st+1) 。注意, a ~ t + 1 \tilde{a}_{t+1} a~t+1 只是假想的动作, 智能体不予执行。
  3. 用价值网络计算 ( s t + 1 , a ~ t + 1 ) \left(s_{t+1}, \tilde{a}_{t+1}\right) (st+1,a~t+1) 的价值:

q ^ t + 1 = q ( s t + 1 , a ~ t + 1 ; w now  ) . \widehat{q}_{t+1}=q\left(s_{t+1}, \tilde{a}_{t+1} ; \boldsymbol{w}_{\text {now }}\right) . q t+1=q(st+1,a~t+1;wnow ).

  1. 计算 TD 目标和 TD 误差:

y ^ t = r t + γ ⋅ q ^ t + 1 , δ t = q ^ t − y ^ t . \widehat{y}_t=r_t+\gamma \cdot \widehat{q}_{t+1}, \quad \delta_t=\widehat{q}_t-\widehat{y}_t . y t=rt+γq t+1,δt=q ty t.

  1. 对价值网络 q q q 做反向传播, 计算 q q q 关于 w \boldsymbol{w} w 的梯度: ∇ w q ( s t , a t ; w now  ) \nabla_{\boldsymbol{w}} q\left(s_t, a_t ; \boldsymbol{w}_{\text {now }}\right) wq(st,at;wnow )
  2. 更新价值网络参数:

w new  ← w now  − α ⋅ δ t ⋅ ∇ w q ( s t , a t ; w now  ) . \boldsymbol{w}_{\text {new }} \leftarrow \boldsymbol{w}_{\text {now }}-\alpha \cdot \delta_t \cdot \nabla_{\boldsymbol{w}} q\left(s_t, a_t ; \boldsymbol{w}_{\text {now }}\right) . wnew wnow αδtwq(st,at;wnow ).

  1. 用某种算法更新策略函数。该算法与 SARSA 算法无关。

三.多步TD目标

在第二节我们证明了以下定理

R k R_k Rk S k 、 A k 、 S k + 1 S_k 、 A_k 、 S_{k+1} SkAkSk+1 的函数, ∀ k = 1 , ⋯   , n \forall k=1, \cdots, n k=1,,n 。 那么
Q π ( s t , a t ) ⏟ U t  的期望  = E S t + 1 , A t + 1 , ⋯   , S t + m , A t + m [ ( ∑ i = 0 m − 1 γ i R t + i ) + γ m ⋅ Q π ( S t + m , A t + m ) ⏟ U t + m  的期望  ∣ S t = s t , A t = a t ] . \underbrace{Q_\pi\left(s_t, a_t\right)}_{U_t \text { 的期望 }}=\mathbb{E}_{S_{t+1}, A_{t+1}, \cdots, S_{t+m}, A_{t+m}}[\left(\sum_{i=0}^{m-1} \gamma^i R_{t+i}\right)+\gamma^m \cdot \underbrace{Q_\pi\left(S_{t+m}, A_{t+m}\right)}_{U_{t+m} \text { 的期望 }} \mid S_t=s_t, A_t=a_t] . Ut 的期望  Qπ(st,at)=ESt+1,At+1,,St+m,At+m[(i=0m1γiRt+i)+γmUt+m 的期望  Qπ(St+m,At+m)St=st,At=at].

已知当前状态 s t s_t st, 用策略 π \pi π 控制智能体与环境交互 m m m 次, 得到轨迹
a t , r t , s t + 1 , a t + 1 , r t + 1 , ⋯   , s t + m − 1 , a t + m − 1 , r t + m − 1 , s t + m , a t + m . a_t,r_t, s_{t+1}, a_{t+1}, r_{t+1}, \cdots, s_{t+m-1}, a_{t+m-1}, r_{t+m-1}, s_{t+m}, a_{t+m} . at,rt,st+1,at+1,rt+1,,st+m1,at+m1,rt+m1,st+m,at+m.

t + m t+m t+m 时刻, 用观测到的轨迹对上式中的期望做蒙特卡洛近似, 把近似的结果记作:
( ∑ i = 0 m − 1 γ i r t + i ) + γ m ⋅ Q π ( s t + m , a t + m ) . \left(\sum_{i=0}^{m-1} \gamma^i r_{t+i}\right)+\gamma^m \cdot Q_\pi\left(s_{t+m}, a_{t+m}\right) . (i=0m1γirt+i)+γmQπ(st+m,at+m).

进一步用 q ( s t + m , a t + m ; w ) q\left(s_{t+m}, a_{t+m} ; \boldsymbol{w}\right) q(st+m,at+m;w) 近似 Q π ( s t + m , a t + m ) Q_\pi\left(s_{t+m}, a_{t+m}\right) Qπ(st+m,at+m), 得到:
y ^ t ≜ ( ∑ i = 0 m − 1 γ i r t + i ) + γ m ⋅ q ( s t + m , a t + m ; w ) . \widehat{y}_t \triangleq\left(\sum_{i=0}^{m-1} \gamma^i r_{t+i}\right)+\gamma^m \cdot q\left(s_{t+m}, a_{t+m} ; \boldsymbol{w}\right) . y t(i=0m1γirt+i)+γmq(st+m,at+m;w).

y ^ t \widehat{y}_t y t 称作 m m m 步 TD 目标。

q ^ t = q ( s t , a t ; w ) \widehat{q}_t=q\left(s_t, a_t ; \boldsymbol{w}\right) q t=q(st,at;w) y ^ t \widehat{y}_t y t 分别是价值网络在 t t t 时刻和 t + m t+m t+m 时刻做出的预测, 两者都是对 Q π ( s t , a t ) Q_\pi\left(s_t, a_t\right) Qπ(st,at) 的估计值。 q ^ t \widehat{q}_t q t 是纯粹的预测, 而 y ^ t \widehat{y}_t y t 则基于 m m m 组实际观测, 因此 y ^ t \widehat{y}_t y t q ^ t \widehat{q}_t q t 更可靠。我们鼓励 q ^ t \widehat{q}_t q t 接近 y ^ t \widehat{y}_t y t 。设损失函数为
L ( w ) ≜ 1 2 [ q ( s t , a t ; w ) − y t ^ ] 2 . L(\boldsymbol{w}) \triangleq \frac{1}{2}\left[q\left(s_t, a_t ; \boldsymbol{w}\right)-\widehat{y_t}\right]^2 . L(w)21[q(st,at;w)yt ]2.

做一步梯度下降更新价值网络参数 w \boldsymbol{w} w :
w ← w − α ⋅ ( q ^ t − y ^ t ) ⋅ ∇ w q ( s t , a t ; w ) . \boldsymbol{w} \leftarrow \boldsymbol{w}-\alpha \cdot\left(\widehat{q}_t-\widehat{y}_t\right) \cdot \nabla_{\boldsymbol{w}} q\left(s_t, a_t ; \boldsymbol{w}\right) . wwα(q ty t)wq(st,at;w).
训练流程 : 设当前价值网络的参数为 w now  \boldsymbol{w}_{\text {now }} wnow , 当前策略为 π now  \pi_{\text {now }} πnow  执行以下步骤更新价值网络和策略。

  1. 用策略网络 π now  \pi_{\text {now }} πnow  控制智能体与环境交互, 完成一个回合, 得到轨迹:

s 1 , a 1 , r 1 , s 2 , a 2 , r 2 , ⋯   , s n , a n , r n . s_1, a_1, r_1, s_2, a_2, r_2, \cdots, s_n, a_n, r_n . s1,a1,r1,s2,a2,r2,,sn,an,rn.

  1. 对于所有的 t = 1 , ⋯   , n − m t=1, \cdots, n-m t=1,,nm, 计算

q ^ t = q ( s t , a t ; w now  ) . \widehat{q}_t=q\left(s_t, a_t ; \boldsymbol{w}_{\text {now }}\right) . q t=q(st,at;wnow ).

  1. 对于所有的 t = 1 , ⋯   , n − m t=1, \cdots, n-m t=1,,nm, 计算多步 TD 目标和 TD 误差:

y ^ t = ∑ i = 0 m − 1 γ i r t + i + γ m q ^ t + m , δ t = q ^ t − y ^ t . \widehat{y}_t=\sum_{i=0}^{m-1} \gamma^i r_{t+i}+\gamma^m \widehat{q}_{t+m}, \quad \delta_t=\widehat{q}_t-\widehat{y}_t . y t=i=0m1γirt+i+γmq t+m,δt=q ty t.

  1. 对于所有的 t = 1 , ⋯   , n − m t=1, \cdots, n-m t=1,,nm, 对价值网络 q q q 做反向传播, 计算 q q q 关于 w \boldsymbol{w} w 的梯度:

∇ w q ( s t , a t ; w now  ) . \nabla_{\boldsymbol{w}} q\left(s_t, a_t ; \boldsymbol{w}_{\text {now }}\right) . wq(st,at;wnow ).

  1. 更新价值网络参数:

w new  ← w now  − α ⋅ ∑ t = 1 n − m δ t ⋅ ∇ w q ( s t , a t ; w now  ) . \boldsymbol{w}_{\text {new }} \leftarrow \boldsymbol{w}_{\text {now }}-\alpha \cdot \sum_{t=1}^{n-m} \delta_t \cdot \nabla_{\boldsymbol{w}} q\left(s_t, a_t ; \boldsymbol{w}_{\text {now }}\right) . wnew wnow αt=1nmδtwq(st,at;wnow ).

  1. 用某种算法更新策略函数 π \pi π 。该算法与 SARSA 算法无关。
  • 20
    点赞
  • 19
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值