深度强化学习(四)SARSA算法
一.SARSA
假设状态空间 S \mathcal{S} S 和动作空间 A \mathcal{A} A 都是有限集, 即集合中元素数量有限。比如, S \mathcal{S} S 中一共有 3 种状态, A \mathcal{A} A 中一共有 4 种动作。那么动作价值函数 Q π ( s , a ) Q_\pi(s, a) Qπ(s,a) 可以表示为一个 3 × 4 3 \times 4 3×4 的表格。该表格与一个策略函数 π ( a ∣ s ) \pi(a \mid s) π(a∣s) 相关联; 如果 π \pi π 发生变化,表格 Q π Q_\pi Qπ 也会发生变化。
我们用表格 q q q 近似 Q π Q_\pi Qπ 。首先初始化 q q q, 可以让它是全零的表格。然后用表格形式的 SARSA 算法更新 q q q,每次更新表格的一个元素。最终 q q q 收敛到 Q π Q_\pi Qπ 。
SARSA 算法由下面的贝尔曼方程推导出 :
Q
π
(
s
t
,
a
t
)
=
E
S
t
+
1
,
A
t
+
1
[
R
t
+
γ
⋅
Q
π
(
S
t
+
1
,
A
t
+
1
)
∣
S
t
=
s
t
,
A
t
=
a
t
]
Q_\pi\left(s_t, a_t\right)=\mathbb{E}_{S_{t+1}, A_{t+1}}\left[R_t+\gamma \cdot Q_\pi\left(S_{t+1}, A_{t+1}\right) \mid S_t=s_t, A_t=a_t\right]
Qπ(st,at)=ESt+1,At+1[Rt+γ⋅Qπ(St+1,At+1)∣St=st,At=at]
我们对贝尔曼方程左右两边做近似:
- 方程左边的 Q π ( s t , a t ) Q_\pi\left(s_t, a_t\right) Qπ(st,at) 可以近似成 q ( s t , a t ) 。 q ( s t , a t ) q\left(s_t, a_t\right) 。 q\left(s_t, a_t\right) q(st,at)。q(st,at) 是表格在 t t t 时刻对 Q π ( s t , a t ) Q_\pi\left(s_t, a_t\right) Qπ(st,at)做出的估计。
- 方程右边的期望是关于下一时刻状态 S t + 1 S_{t+1} St+1 和动作 A t + 1 A_{t+1} At+1 求的。给定当前状态 s t s_t st, 智能体执行动作 a t a_t at, 环境会给出奖励 r t r_t rt 和新的状态 s t + 1 s_{t+1} st+1 。然后基于 s t + 1 s_{t+1} st+1 做随机抽样,得到新的动作
a ~ t + 1 ∼ π ( ⋅ ∣ s t + 1 ) . \tilde{a}_{t+1} \sim \pi\left(\cdot \mid s_{t+1}\right) . a~t+1∼π(⋅∣st+1).
用观测到的
r
t
、
s
t
+
1
r_t 、 s_{t+1}
rt、st+1 和计算出的
a
~
t
+
1
\tilde{a}_{t+1}
a~t+1 对期望做蒙特卡洛近似, 得到:
r
t
+
γ
⋅
Q
π
(
s
t
+
1
,
a
~
t
+
1
)
.
r_t+\gamma \cdot Q_\pi\left(s_{t+1}, \tilde{a}_{t+1}\right) .
rt+γ⋅Qπ(st+1,a~t+1).
- 进一步把公式 (5.1) 中的 Q π Q_\pi Qπ 近似成 q q q, 得到
y ^ t ≜ r t + γ ⋅ q ( s t + 1 , a ~ t + 1 ) . \widehat{y}_t \triangleq r_t+\gamma \cdot q\left(s_{t+1}, \tilde{a}_{t+1}\right) . y t≜rt+γ⋅q(st+1,a~t+1).
把它称作 TD 目标。它是表格在
t
+
1
t+1
t+1 时刻对
Q
π
(
s
t
,
a
t
)
Q_\pi\left(s_t, a_t\right)
Qπ(st,at) 做出的估计。
q
(
s
t
,
a
t
)
q\left(s_t, a_t\right)
q(st,at) 和
y
^
t
\widehat{y}_t
y
t 都是对动作价值
Q
π
(
s
t
,
a
t
)
Q_\pi\left(s_t, a_t\right)
Qπ(st,at) 的估计。由于
y
^
t
\widehat{y}_t
y
t 部分基于真实观测到的奖励
r
t
r_t
rt,我们认为
y
^
t
\widehat{y}_t
y
t 是更可靠的估计, 所以鼓励
q
(
s
t
,
a
t
)
q\left(s_t, a_t\right)
q(st,at) 趋近
y
^
t
\widehat{y}_t
y
t 。更新表格
(
s
t
,
a
t
)
\left(s_t, a_t\right)
(st,at) 位置上的元素:
q
(
s
t
,
a
t
)
←
(
1
−
α
)
⋅
q
(
s
t
,
a
t
)
+
α
⋅
y
^
t
.
q\left(s_t, a_t\right) \leftarrow(1-\alpha) \cdot q\left(s_t, a_t\right)+\alpha \cdot \widehat{y}_t .
q(st,at)←(1−α)⋅q(st,at)+α⋅y
t.
这样可以使得 q ( s t , a t ) q\left(s_t, a_t\right) q(st,at) 更接近 y ^ t \widehat{y}_t y t 。 SARSA 算法用到了这个五元组: ( s t , a t , r t , s t + 1 , a ~ t + 1 ) \left(s_t, a_t, r_t, s_{t+1}, \tilde{a}_{t+1}\right) (st,at,rt,st+1,a~t+1) 。SARSA 算法学到的 q q q 依赖于策略 π \pi π, 这是因为五元组中的 a ~ t + 1 \tilde{a}_{t+1} a~t+1 是根据 π ( ⋅ ∣ s t + 1 ) \pi\left(\cdot \mid s_{t+1}\right) π(⋅∣st+1) 抽样得到的。
训练流程:设当前表格为 q now q_{\text {now }} qnow , 当前策略为 π now \pi_{\text {now }} πnow 每一轮更新表格中的一个元素,把更新之后的表格记作 q new q_{\text {new }} qnew 。
- 观测到当前状态 s t s_t st, 根据当前策略做抽样: a t ∼ π now ( ⋅ ∣ s t ) a_t \sim \pi_{\text {now }}\left(\cdot \mid s_t\right) at∼πnow (⋅∣st) 。
- 把表格 q now q_{\text {now }} qnow 中第 ( s t , a t ) \left(s_t, a_t\right) (st,at) 位置上的元素记作:
q ^ t = q now ( s t , a t ) . \widehat{q}_t=q_{\text {now }}\left(s_t, a_t\right) . q t=qnow (st,at).
- 智能体执行动作 a t a_t at 之后, 观测到奖励 r t r_t rt 和新的状态 s t + 1 s_{t+1} st+1 。
- 根据当前策略做抽样: a ~ t + 1 ∼ π now ( ⋅ ∣ s t + 1 ) \tilde{a}_{t+1} \sim \pi_{\text {now }}\left(\cdot \mid s_{t+1}\right) a~t+1∼πnow (⋅∣st+1) 。注意, a ~ t + 1 \tilde{a}_{t+1} a~t+1 只是假想的动作, 智能体不予执行。
- 把表格 q now q_{\text {now }} qnow 中第 ( s t + 1 , a ~ t + 1 ) \left(s_{t+1}, \tilde{a}_{t+1}\right) (st+1,a~t+1) 位置上的元素记作:
q ^ t + 1 = q now ( s t + 1 , a ~ t + 1 ) . \widehat{q}_{t+1}=q_{\text {now }}\left(s_{t+1}, \tilde{a}_{t+1}\right) . q t+1=qnow (st+1,a~t+1).
- 计算 TD 目标和 TD 误差:
y ^ t = r t + γ ⋅ q ^ t + 1 , δ t = q ^ t − y ^ t . \widehat{y}_t=r_t+\gamma \cdot \widehat{q}_{t+1}, \quad \delta_t=\widehat{q}_t-\widehat{y}_t . y t=rt+γ⋅q t+1,δt=q t−y t.
- 更新表格中 ( s t , a t ) \left(s_t, a_t\right) (st,at) 位置上的元素:
q new ( s t , a t ) ← q now ( s t , a t ) − α ⋅ δ t . q_{\text {new }}\left(s_t, a_t\right) \leftarrow q_{\text {now }}\left(s_t, a_t\right)-\alpha \cdot \delta_t . qnew (st,at)←qnow (st,at)−α⋅δt.
- 用某种算法更新策略函数。该算法与 SARSA 算法无关。
二.神经网络形式的SARSA
**价值网络:**如果状态空间
S
\mathcal{S}
S 是无限集, 那么我们无法用一张表格表示
Q
π
Q_\pi
Qπ, 否则表格的行数是无穷。一种可行的方案是用一个神经网络
q
(
s
,
a
;
w
)
q(s, a ; \boldsymbol{w})
q(s,a;w) 来近似
Q
π
(
s
,
a
)
Q_\pi(s, a)
Qπ(s,a); 理想情况下,
q
(
s
,
a
;
w
)
=
Q
π
(
s
,
a
)
,
∀
s
∈
S
,
a
∈
A
q(s, a ; \boldsymbol{w})=Q_\pi(s, a), \quad \forall s \in \mathcal{S}, a \in \mathcal{A}
q(s,a;w)=Qπ(s,a),∀s∈S,a∈A
训练流程 : 设当前价值网络的参数为
w
n
o
w
\boldsymbol{w}_{\mathrm{now}}
wnow, 当前策略为
π
n
o
w
\pi_{\mathrm{now}}
πnow 每一轮训练用五元组
(
s
t
,
a
t
,
r
t
,
s
t
+
1
,
a
~
t
+
1
)
\left(s_t, a_t, r_t, s_{t+1}, \tilde{a}_{t+1}\right)
(st,at,rt,st+1,a~t+1) 对价值网络参数做一次更新。
- 观测到当前状态 s t s_t st, 根据当前策略做抽样: a t ∼ π now ( ⋅ ∣ s t ) a_t \sim \pi_{\text {now }}\left(\cdot \mid s_t\right) at∼πnow (⋅∣st) 。
- 用价值网络计算 ( s t , a t ) \left(s_t, a_t\right) (st,at) 的价值:
q ^ t = q ( s t , a t ; w now ) . \widehat{q}_t=q\left(s_t, a_t ; \boldsymbol{w}_{\text {now }}\right) . q t=q(st,at;wnow ).
- 智能体执行动作 a t a_t at 之后, 观测到奖励 r t r_t rt 和新的状态 s t + 1 s_{t+1} st+1 。
- 根据当前策略做抽样: a ~ t + 1 ∼ π n o w ( ⋅ ∣ s t + 1 ) \tilde{a}_{t+1} \sim \pi_{\mathrm{now}}\left(\cdot \mid s_{t+1}\right) a~t+1∼πnow(⋅∣st+1) 。注意, a ~ t + 1 \tilde{a}_{t+1} a~t+1 只是假想的动作, 智能体不予执行。
- 用价值网络计算 ( s t + 1 , a ~ t + 1 ) \left(s_{t+1}, \tilde{a}_{t+1}\right) (st+1,a~t+1) 的价值:
q ^ t + 1 = q ( s t + 1 , a ~ t + 1 ; w now ) . \widehat{q}_{t+1}=q\left(s_{t+1}, \tilde{a}_{t+1} ; \boldsymbol{w}_{\text {now }}\right) . q t+1=q(st+1,a~t+1;wnow ).
- 计算 TD 目标和 TD 误差:
y ^ t = r t + γ ⋅ q ^ t + 1 , δ t = q ^ t − y ^ t . \widehat{y}_t=r_t+\gamma \cdot \widehat{q}_{t+1}, \quad \delta_t=\widehat{q}_t-\widehat{y}_t . y t=rt+γ⋅q t+1,δt=q t−y t.
- 对价值网络 q q q 做反向传播, 计算 q q q 关于 w \boldsymbol{w} w 的梯度: ∇ w q ( s t , a t ; w now ) \nabla_{\boldsymbol{w}} q\left(s_t, a_t ; \boldsymbol{w}_{\text {now }}\right) ∇wq(st,at;wnow ) 。
- 更新价值网络参数:
w new ← w now − α ⋅ δ t ⋅ ∇ w q ( s t , a t ; w now ) . \boldsymbol{w}_{\text {new }} \leftarrow \boldsymbol{w}_{\text {now }}-\alpha \cdot \delta_t \cdot \nabla_{\boldsymbol{w}} q\left(s_t, a_t ; \boldsymbol{w}_{\text {now }}\right) . wnew ←wnow −α⋅δt⋅∇wq(st,at;wnow ).
- 用某种算法更新策略函数。该算法与 SARSA 算法无关。
三.多步TD目标
在第二节我们证明了以下定理
设 R k R_k Rk 是 S k 、 A k 、 S k + 1 S_k 、 A_k 、 S_{k+1} Sk、Ak、Sk+1 的函数, ∀ k = 1 , ⋯ , n \forall k=1, \cdots, n ∀k=1,⋯,n 。 那么
Q π ( s t , a t ) ⏟ U t 的期望 = E S t + 1 , A t + 1 , ⋯ , S t + m , A t + m [ ( ∑ i = 0 m − 1 γ i R t + i ) + γ m ⋅ Q π ( S t + m , A t + m ) ⏟ U t + m 的期望 ∣ S t = s t , A t = a t ] . \underbrace{Q_\pi\left(s_t, a_t\right)}_{U_t \text { 的期望 }}=\mathbb{E}_{S_{t+1}, A_{t+1}, \cdots, S_{t+m}, A_{t+m}}[\left(\sum_{i=0}^{m-1} \gamma^i R_{t+i}\right)+\gamma^m \cdot \underbrace{Q_\pi\left(S_{t+m}, A_{t+m}\right)}_{U_{t+m} \text { 的期望 }} \mid S_t=s_t, A_t=a_t] . Ut 的期望 Qπ(st,at)=ESt+1,At+1,⋯,St+m,At+m[(i=0∑m−1γiRt+i)+γm⋅Ut+m 的期望 Qπ(St+m,At+m)∣St=st,At=at].
已知当前状态
s
t
s_t
st, 用策略
π
\pi
π 控制智能体与环境交互
m
m
m 次, 得到轨迹
a
t
,
r
t
,
s
t
+
1
,
a
t
+
1
,
r
t
+
1
,
⋯
,
s
t
+
m
−
1
,
a
t
+
m
−
1
,
r
t
+
m
−
1
,
s
t
+
m
,
a
t
+
m
.
a_t,r_t, s_{t+1}, a_{t+1}, r_{t+1}, \cdots, s_{t+m-1}, a_{t+m-1}, r_{t+m-1}, s_{t+m}, a_{t+m} .
at,rt,st+1,at+1,rt+1,⋯,st+m−1,at+m−1,rt+m−1,st+m,at+m.
在
t
+
m
t+m
t+m 时刻, 用观测到的轨迹对上式中的期望做蒙特卡洛近似, 把近似的结果记作:
(
∑
i
=
0
m
−
1
γ
i
r
t
+
i
)
+
γ
m
⋅
Q
π
(
s
t
+
m
,
a
t
+
m
)
.
\left(\sum_{i=0}^{m-1} \gamma^i r_{t+i}\right)+\gamma^m \cdot Q_\pi\left(s_{t+m}, a_{t+m}\right) .
(i=0∑m−1γirt+i)+γm⋅Qπ(st+m,at+m).
进一步用
q
(
s
t
+
m
,
a
t
+
m
;
w
)
q\left(s_{t+m}, a_{t+m} ; \boldsymbol{w}\right)
q(st+m,at+m;w) 近似
Q
π
(
s
t
+
m
,
a
t
+
m
)
Q_\pi\left(s_{t+m}, a_{t+m}\right)
Qπ(st+m,at+m), 得到:
y
^
t
≜
(
∑
i
=
0
m
−
1
γ
i
r
t
+
i
)
+
γ
m
⋅
q
(
s
t
+
m
,
a
t
+
m
;
w
)
.
\widehat{y}_t \triangleq\left(\sum_{i=0}^{m-1} \gamma^i r_{t+i}\right)+\gamma^m \cdot q\left(s_{t+m}, a_{t+m} ; \boldsymbol{w}\right) .
y
t≜(i=0∑m−1γirt+i)+γm⋅q(st+m,at+m;w).
把 y ^ t \widehat{y}_t y t 称作 m m m 步 TD 目标。
q
^
t
=
q
(
s
t
,
a
t
;
w
)
\widehat{q}_t=q\left(s_t, a_t ; \boldsymbol{w}\right)
q
t=q(st,at;w) 和
y
^
t
\widehat{y}_t
y
t 分别是价值网络在
t
t
t 时刻和
t
+
m
t+m
t+m 时刻做出的预测, 两者都是对
Q
π
(
s
t
,
a
t
)
Q_\pi\left(s_t, a_t\right)
Qπ(st,at) 的估计值。
q
^
t
\widehat{q}_t
q
t 是纯粹的预测, 而
y
^
t
\widehat{y}_t
y
t 则基于
m
m
m 组实际观测, 因此
y
^
t
\widehat{y}_t
y
t 比
q
^
t
\widehat{q}_t
q
t 更可靠。我们鼓励
q
^
t
\widehat{q}_t
q
t 接近
y
^
t
\widehat{y}_t
y
t 。设损失函数为
L
(
w
)
≜
1
2
[
q
(
s
t
,
a
t
;
w
)
−
y
t
^
]
2
.
L(\boldsymbol{w}) \triangleq \frac{1}{2}\left[q\left(s_t, a_t ; \boldsymbol{w}\right)-\widehat{y_t}\right]^2 .
L(w)≜21[q(st,at;w)−yt
]2.
做一步梯度下降更新价值网络参数
w
\boldsymbol{w}
w :
w
←
w
−
α
⋅
(
q
^
t
−
y
^
t
)
⋅
∇
w
q
(
s
t
,
a
t
;
w
)
.
\boldsymbol{w} \leftarrow \boldsymbol{w}-\alpha \cdot\left(\widehat{q}_t-\widehat{y}_t\right) \cdot \nabla_{\boldsymbol{w}} q\left(s_t, a_t ; \boldsymbol{w}\right) .
w←w−α⋅(q
t−y
t)⋅∇wq(st,at;w).
训练流程 : 设当前价值网络的参数为
w
now
\boldsymbol{w}_{\text {now }}
wnow , 当前策略为
π
now
\pi_{\text {now }}
πnow 执行以下步骤更新价值网络和策略。
- 用策略网络 π now \pi_{\text {now }} πnow 控制智能体与环境交互, 完成一个回合, 得到轨迹:
s 1 , a 1 , r 1 , s 2 , a 2 , r 2 , ⋯ , s n , a n , r n . s_1, a_1, r_1, s_2, a_2, r_2, \cdots, s_n, a_n, r_n . s1,a1,r1,s2,a2,r2,⋯,sn,an,rn.
- 对于所有的 t = 1 , ⋯ , n − m t=1, \cdots, n-m t=1,⋯,n−m, 计算
q ^ t = q ( s t , a t ; w now ) . \widehat{q}_t=q\left(s_t, a_t ; \boldsymbol{w}_{\text {now }}\right) . q t=q(st,at;wnow ).
- 对于所有的 t = 1 , ⋯ , n − m t=1, \cdots, n-m t=1,⋯,n−m, 计算多步 TD 目标和 TD 误差:
y ^ t = ∑ i = 0 m − 1 γ i r t + i + γ m q ^ t + m , δ t = q ^ t − y ^ t . \widehat{y}_t=\sum_{i=0}^{m-1} \gamma^i r_{t+i}+\gamma^m \widehat{q}_{t+m}, \quad \delta_t=\widehat{q}_t-\widehat{y}_t . y t=i=0∑m−1γirt+i+γmq t+m,δt=q t−y t.
- 对于所有的 t = 1 , ⋯ , n − m t=1, \cdots, n-m t=1,⋯,n−m, 对价值网络 q q q 做反向传播, 计算 q q q 关于 w \boldsymbol{w} w 的梯度:
∇ w q ( s t , a t ; w now ) . \nabla_{\boldsymbol{w}} q\left(s_t, a_t ; \boldsymbol{w}_{\text {now }}\right) . ∇wq(st,at;wnow ).
- 更新价值网络参数:
w new ← w now − α ⋅ ∑ t = 1 n − m δ t ⋅ ∇ w q ( s t , a t ; w now ) . \boldsymbol{w}_{\text {new }} \leftarrow \boldsymbol{w}_{\text {now }}-\alpha \cdot \sum_{t=1}^{n-m} \delta_t \cdot \nabla_{\boldsymbol{w}} q\left(s_t, a_t ; \boldsymbol{w}_{\text {now }}\right) . wnew ←wnow −α⋅t=1∑n−mδt⋅∇wq(st,at;wnow ).
- 用某种算法更新策略函数 π \pi π 。该算法与 SARSA 算法无关。