SARSA,它是一种 TD 算法,SARSA 的目的是学习动作价值函数
Q
π
(
s
,
a
)
Q_π(s, a)
Qπ(s,a)。
Q
π
Q_π
Qπ 通常被用于评价策略的好坏,而非用于控制智能体。
Q
π
Q_π
Qπ 常与策略函数 π 结合使用,被称作 actor-critic(演员—评委)方法。策略函数 π 控制智能体,因此被看做“演员”;而
Q
π
Q_π
Qπ 评价 π 的表现,帮助改进 π,因此
Q
π
Q_π
Qπ 被看做“评委”。Actor-critic 通常用 SARSA 训练“评委”
Q
π
Q_π
Qπ。
SARSA 算法的推导
SARSA 算法由下面的贝尔曼方程推导出:
用一个神经网络 q(s, a; w) 来近似
Q
π
(
s
,
a
)
Q_π(s, a)
Qπ(s,a)。
给定当前状态
s
t
s_t
st,智能体执行动作
a
t
a_t
at,环境会给出奖励
r
t
r_t
rt 和新的状态
s
t
+
1
s_{t+1}
st+1。然后基于
s
t
+
1
s_{t+1}
st+1 做随机抽样,得到新的动作
a
t
+
1
~
∼
π
(
⋅
∣
s
t
+
1
)
\tilde{a_{t+1}} \sim π(· | s_{t+1})
at+1~∼π(⋅∣st+1)。定义 TD 目标:
y
t
^
≜
r
t
+
γ
⋅
q
(
s
t
+
1
,
a
t
+
1
^
;
w
)
\hat{y_t} ≜ r_t + \gamma \cdot q(s_{t+1}, \hat{a_{t+1}}; w)
yt^≜rt+γ⋅q(st+1,at+1^;w)
我们鼓励
q
(
s
t
,
a
t
;
w
)
q(s_t, a_t; w)
q(st,at;w) 接近 TD 目标
y
t
^
\hat{y_t}
yt^,所以定义损失函数:
L
(
w
)
≜
1
2
[
q
(
s
t
,
a
t
;
w
)
−
y
t
^
]
2
L(w) ≜ \frac{1}{2} [q(s_t, a_t;w) - \hat{y_t}]^2
L(w)≜21[q(st,at;w)−yt^]2
损失函数的变量是 w,而
y
t
^
\hat{y_t}
yt^被视为常数。(尽管
y
t
^
\hat{y_t}
yt^也依赖于参数 w,但这一点被忽略掉。)设
q
t
^
\hat{q_t}
qt^=
q
(
s
t
,
a
t
;
w
)
q(s_t, a_t; w)
q(st,at;w)。损失函数关于 w 的梯度是:
∇
w
L
(
w
)
=
(
q
t
^
−
y
t
^
)
⏟
T
D
误差
δ
t
⋅
∇
w
q
(
s
t
,
a
t
;
w
)
\nabla_{w} L(w) = \underbrace{(\hat{q_t} - \hat{y_t})}_{TD误差\delta_t} \cdot \nabla_{w}q(s_t, a_t; w)
∇wL(w)=TD误差δt
(qt^−yt^)⋅∇wq(st,at;w)
做一次梯度下降更新 w:
w
←
w
−
α
⋅
δ
t
⋅
∇
w
q
(
s
t
,
a
t
;
w
)
w \leftarrow w - \alpha \cdot \delta_t \cdot \nabla_wq(s_t, a_t; w)
w←w−α⋅δt⋅∇wq(st,at;w)
训练流程
设当前价值网络的参数为 w n o w w_{now} wnow,当前策略为 π n o w π_{now} πnow。每一轮训练用五元组 ( s t , a t , r t , s t + 1 , a t + 1 ~ ) (s_t, a_t, r_t, s_{t+1}, \tilde{a_{t+1}}) (st,at,rt,st+1,at+1~) 对价值网络参数做一次更新。
- 观测到当前状态 s t s_t st,根据当前策略做抽样: a t ∼ π n o w ( ⋅ ∣ s t ) a_t ∼ π_{now}(· | s_t) at∼πnow(⋅∣st)。
- 用价值网络计算
(
s
t
,
a
t
)
(s_t, a_t)
(st,at) 的价值:
q t ^ = q ( s t , a t ; w n o w ) \hat{q_t} = q(s_t, a_t; w_{now}) qt^=q(st,at;wnow) - 智能体执行动作 a t a_t at 之后,观测到奖励 r t r_t rt 和新的状态 s t + 1 s_{t+1} st+1。
- 根据当前策略做抽样: a t + 1 ^ ∼ π n o w ( ⋅ ∣ s t + 1 ) \hat{a_{t+1}} ∼ π_{now}(· |s_{t+1}) at+1^∼πnow(⋅∣st+1)。注意, a t + 1 ^ \hat{a_{t+1}} at+1^ 只是假想的动作,智能体不予执行。
- 用价值网络计算
(
s
t
+
1
,
a
t
+
1
^
)
(s_{t+1}, \hat{a_{t+1}})
(st+1,at+1^) 的价值:
q t + 1 ^ = q ( s t + 1 , a t + 1 ^ ; w n o w ) \hat{q_{t+1}} = q(s_{t+1}, \hat{a_{t+1}};w_{now}) qt+1^=q(st+1,at+1^;wnow) - 计算 TD 目标和 TD 误差:
y t ^ = r t + γ ⋅ q t + 1 ^ , δ t = q t ^ − y t ^ \hat{y_t} = r_t + \gamma \cdot \hat{q_{t+1}}, \delta_t = \hat{q_t} - \hat{y_t} yt^=rt+γ⋅qt+1^,δt=qt^−yt^ - 对价值网络 q 做反向传播,计算 q 关于 w 的梯度:
∇ w q ( s t , a t ; w n o w ) \nabla_{w}q(s_t, a_t; w_{now}) ∇wq(st,at;wnow) - 更新价值网络参数:
w n e w ← w n o w − α ⋅ δ t ⋅ ∇ w q ( s t , a t ; w n o w ) w_{new} \leftarrow w_{now} - \alpha \cdot \delta_t \cdot \nabla_wq(s_t, a_t; w_{now}) wnew←wnow−α⋅δt⋅∇wq(st,at;wnow) - 用某种算法更新策略函数。该算法与 SARSA 算法无关