上篇文章强化学习——时序差分 (TD) 控制算法 Sarsa 和 Q-Learning主要介绍了 Sarsa 和 Q-Learning 两种时序差分控制算法,在这两种算法内部都要维护一张 Q 表格,对于小型的强化学习问题是非常灵活高效的。但是在状态和可选动作非常多的问题中,这张Q表格就变得异常巨大,甚至超出内存,而且查找效率极其低下,从而限制了时序差分的应用场景。近些年来,随着神经网络的兴起,基于深度学习的强化学习称为了主流,也就是深度强化学习(DRL)。
一、函数逼近介绍
限制 Sarsa 和 Q-Learning 的应用场景原因是需要维护一张巨大的 Q 表格,那么能不能用其他的方式来代替 Q表格呢?很自然的,就想到了函数。
v
^
(
s
,
w
)
≈
v
π
(
s
)
q
^
(
s
,
a
,
w
)
≈
q
π
(
s
,
a
)
π
^
a
,
s
,
w
≈
π
(
a
∣
s
)
\hat{v}(s, w) \approx v_\pi(s) \\ \hat{q}(s,a, w) \approx q_\pi(s, a) \\ \hat{\pi}{a,s,w} \approx \pi(a|s)
v^(s,w)≈vπ(s)q^(s,a,w)≈qπ(s,a)π^a,s,w≈π(a∣s)
也就是说可以用一个函数来代替 Q 表格,不断更新
q
(
s
,
a
)
q(s,a)
q(s,a) 的过程就可以转化为用参数来拟合逼近真实 q 值的过程。这样学习的过程不是更新 Q 表格,而是更新 参数 w 的过程。
下面是几种不同的拟合方式:
第一种函数接受当前的 状态 S 作为输入,输出拟合后的价值函数
第二种函数同时接受 状态 S 和 动作 a 作为输入,输出拟合后的动作价值函数
第三种函数接受状态 S,输出每个动作对应的动作价值函数 q
常见逼近函数有线性特征组合方式、神经网络、决策树、最近邻等,在这里只讨论可微分的拟合函数:线性特征组合和神经网络两种方式。
1、知道真实 V 的函数逼近
对于给定的一个状态 S 假定知道真实的
v
π
(
s
)
v_\pi(s)
vπ(s) ,然后经过拟合得到
v
^
(
s
,
w
)
\hat{v}(s, w)
v^(s,w) ,于是就可以使用均方差来计算损失
J
(
w
)
=
E
π
[
(
v
π
(
s
)
−
v
^
(
s
,
w
)
)
2
]
J(w) = E_\pi[(v_\pi(s) - \hat{v}(s, w))^2]
J(w)=Eπ[(vπ(s)−v^(s,w))2]
利用梯度下降去找到局部最小值:
Δ
w
=
−
1
2
α
∇
w
v
^
(
s
,
w
)
w
t
+
1
=
w
t
+
Δ
w
\Delta w = -\frac{1}{2}\alpha \nabla_w\hat{v}(s,w) \\ w_{t+1} = w_t + \Delta w
Δw=−21α∇wv^(s,w)wt+1=wt+Δw
可以提取一些特征向量来表示当前的 状态 S,比如对于 gym 的 CartPole 环境,提取的特征有推车的位置、推车的速度、木杆的角度、木杆的角速度等

此时价值函数 就可以用线性特征组合表示:
v ^ ( s , w ) = x ( s ) T w = ∑ j = 1 n x j ( s ) ⋅ w j \hat{v}(s,w) = x(s)^Tw=\sum_{j=1}^nx_j(s)\cdot w_j v^(s,w)=x(s)Tw=j=1∑nxj(s)⋅wj
使用均方差计算损失函数:
J
(
w
)
=
E
π
[
(
v
π
(
s
)
−
x
(
s
)
T
w
)
2
]
J(w) = E_\pi[(v_\pi(s) - x(s)^T w)^2]
J(w)=Eπ[(vπ(s)−x(s)Tw)2]
因此更新规则为:
Δ
w
=
α
(
v
π
(
s
)
−
v
^
(
s
,
w
)
)
⋅
x
(
s
)
U
p
d
a
t
e
=
S
t
e
p
S
i
z
e
∗
P
r
e
d
i
c
t
i
o
n
E
r
r
o
r
∗
F
e
a
t
u
r
e
V
a
l
u
e
\Delta w = \alpha(v_\pi(s)-\hat{v}(s,w))\cdot x(s) \\ Update = StepSize\;*\;PredictionError\;*\;FeatureValue
Δw=α(vπ(s)−v^(s,w))⋅x(s)Update=StepSize∗PredictionError∗FeatureValue
二、预测过程中的价值函数逼近
因为函数逼近的就是真实的状态价值,所以在实际的强化学习问题中是没有
v
π
(
s
)
v_\pi(s)
vπ(s) 的,只有奖励。所以在函数逼近过程的监督数据为:
<
S
1
,
G
1
>
,
<
S
2
,
G
2
>
,
⋯
,
<
S
t
,
G
T
>
<S_1, G_1>, <S_2, G_2>, \cdots ,<S_t, G_T>
<S1,G1>,<S2,G2>,⋯,<St,GT>
对于蒙特卡洛有:
Δ
w
=
α
(
G
t
−
v
^
(
s
t
,
w
)
)
∇
w
v
^
(
s
t
,
w
)
=
α
(
G
t
−
v
^
(
s
t
,
w
)
)
⋅
x
(
s
t
)
\Delta w = \alpha({\color{red}G_t} - \hat{v}(s_t, w))\nabla_w\hat{v}(s_t, w) \\ = \alpha({\color{red}G_t} - \hat{v}(s_t, w)) \cdot x(s_t)
Δw=α(Gt−v^(st,w))∇wv^(st,w)=α(Gt−v^(st,w))⋅x(st)
其中奖励
G
t
G_t
Gt 是无偏(unbiased)的:
E
[
G
t
]
=
v
π
(
s
t
)
E[G_t] = v_\pi(s_t)
E[Gt]=vπ(st) 。值得一提的是,蒙特卡洛预测过程的函数逼近在线性或者是非线性都能收敛。
对于TD算法,使用
v
^
(
s
t
,
w
)
\hat{v}(s_t, w)
v^(st,w) 来代替 TD Target。所以在价值函数逼近(VFA)使用的训练数据如下所示:
<
S
1
,
R
2
+
γ
v
^
(
s
2
,
w
)
>
,
<
S
2
,
R
3
+
γ
v
^
(
s
3
,
w
)
>
,
⋯
,
<
S
T
−
1
,
R
T
>
<S_1, R_2+\gamma \hat{v}(s_2, w)>,<S_2, R_3+\gamma \hat{v}(s_3, w)>,\cdots,<S_{T-1}, R_T>
<S1,R2+γv^(s2,w)>,<S2,R3+γv^(s3,w)>,⋯,<ST−1,RT>
于是对于 TD(0) 在预测过程的函数逼近有:
Δ
w
=
α
(
R
t
+
1
+
γ
v
^
(
s
t
+
1
,
w
)
−
v
^
(
s
t
,
w
)
)
∇
w
v
^
(
s
t
,
w
)
=
α
(
R
t
+
1
+
γ
v
^
(
s
t
+
1
,
w
)
−
v
^
(
s
t
,
w
)
)
⋅
x
(
s
)
\Delta w = \alpha({\color{red}R_{t+1} + \gamma \hat{v}(s_{t+1}, w)}-\hat{v}(s_t, w))\nabla_w\hat{v}(s_t, w) \\ = \alpha({\color{red}R_{t+1} + \gamma \hat{v}(s_{t+1}, w)}-\hat{v}(s_t, w))\cdot x(s)
Δw=α(Rt+1+γv^(st+1,w)−v^(st,w))∇wv^(st,w)=α(Rt+1+γv^(st+1,w)−v^(st,w))⋅x(s)
因为TD中的 Target 中包含了预测的
v
^
(
s
,
t
)
\hat{v}(s,t)
v^(s,t) ,所以它对于真实的
v
π
(
s
t
)
v_\pi(s_t)
vπ(st) 是有偏(biased)的,因为监督数据是估计出来的,而不是真实的数据。也就是
E
[
R
t
+
1
+
γ
v
^
(
s
t
+
1
,
w
)
]
≠
v
π
(
s
t
)
E[R_{t+1} + \gamma \hat{v}(s_{t+1}, w)] \neq v_\pi(s_t)
E[Rt+1+γv^(st+1,w)]=vπ(st) 。所以把这个过程叫做 semi-gradient,不是完全的梯度下降,而是忽略了权重向量 w 对 Target 的影响。
三、控制过程中的价值函数逼近
类比于MC 和 TD 在使用 Q 表格时的更新公式,对于策略控制过程可以得到如下公式。和上面预测过程一样,没有真实的 q π ( s , a ) q_\pi(s,a) qπ(s,a) ,所以对其进行了替代:
- 对于 MC,Target 是 G t G_t Gt :
Δ w = α ( G t − q ^ ( s t , a t , w ) ) ∇ w q ^ ( s t , a t , w ) \Delta w = \alpha({\color{red}G_t} - \hat{q}(s_t, a_t, w))\nabla_w\hat{q}(s_t, a_t, w) Δw=α(Gt−q^(st,at,w))∇wq^(st,at,w)
- 对于 Sarsa,TD Target 是 R t + 1 + γ q ^ ( s t + 1 , a t + 1 , w ) R_{t+1} + \gamma \hat{q}(s_{t+1}, a_{t+1}, w) Rt+1+γq^(st+1,at+1,w) :
Δ w = α ( R t + 1 + γ q ^ ( s t + 1 , a t + 1 , w ) − q ^ ( s t , a t , w ) ) ⋅ ∇ w q ^ ( s t , a t , w ) \Delta w = \alpha ({\color{red}R_{t+1} + \gamma \hat{q}(s_{t+1}, a_{t+1}, w)} - \hat{q}{(s_t, a_t, w)})\cdot \nabla_w\hat{q}{(s_t, a_t, w)} Δw=α(Rt+1+γq^(st+1,at+1,w)−q^(st,at,w))⋅∇wq^(st,at,w)
- 对于 Q-Learning,TD Target 是 R t + 1 + γ m a x a q ^ ( s t + 1 , a t , w ) R_{t+1} + \gamma\;max_a\; \hat{q}(s_{t+1}, a_t, w) Rt+1+γmaxaq^(st+1,at,w) :
Δ w = α ( R t + 1 + γ m a x a q ^ ( s t + 1 , a t , w ) − q ^ ( s t , a t , w ) ) ⋅ ∇ w q ^ ( s t , a t , w ) \Delta w = \alpha ({\color{red}R_{t+1} + \gamma\;max_a\; \hat{q}(s_{t+1}, a_t, w)} - \hat{q}{(s_t, a_t, w)})\cdot \nabla_w\hat{q}{(s_t, a_t, w)} Δw=α(Rt+1+γmaxaq^(st+1,at,w)−q^(st,at,w))⋅∇wq^(st,at,w)
四、关于收敛的问题
在上图中,对于使用 Q 表格的问题,不管是MC还是 Sarsa 和 Q-Learning 都能找到最优状态价值。如果是一个大规模的环境,采用线性特征拟合,其中MC 和 Sarsa 是可以找到一个近似最优解的。当使用非线性拟合(如神经网络),这三种算法都很难保证能找到一个最优解。
其实对于off-policy 的TD Learning强化学习过程收敛是很困难的,主要有以下原因:
- 使用函数估计:对于 Sarsa 和 Q-Learning 中价值函数的的近似,其监督数据 Target 是不等于真实值的,因为TD Target 中包含了需要优化的 参数 w,也叫作 半梯度TD,其中会存在误差。
- Bootstrapping:在更新式子中,上面红色字体过程中有 贝尔曼近似过程,也就是使用之前的估计来估计当前的函数,这个过程中也引入了不确定因素。(在这个过程中MC会比TD好一点,因为MC中代替 Target 的 G t G_t Gt 是无偏的)。
- Off-policy 训练:对于 off-policy 策略控制过程中,使用 behavior policy 来采集数据,在优化的时候使用另外的 target policy 策略来优化,两种不同的策略会导致价值函数的估计变的很不准确。
上面三个因素就导致了强化学习训练的死亡三角,也是强化学习相对于监督学习训练更加困难的原因。
下一篇就来介绍本系列的第一个深度强化学习算法 Deep Q-Learning(DQN)
参考资料:
- B站周老师的 强化学习纲要第四节上