ICLR 2022
Author:Chenjia Bai, Lingxiao Wang, Zhuoran Yang, Zhi-Hong Deng, Animesh Garg, Peng Liu, Zhaoran Wang
Keywords: Pessimistic Bootstrapping, Bootstrapped Q-functions, Uncertainty Estimation, Offline Reinforcement Learning
paper
1. Introduction
离线强化学习旨在利用离线数据,不与环境交互下学习策略。但容易受OOD数据影响导致外推误差。常用解决方法由:1)policy constraint以及2)conservative method。前者限制策略接近行为策略,容易受数据集质量影响;后者则是对OOD状态动作的Q-value值进行惩罚,但容易导致保守的价值估计。
对于OOD数据采用基于模型的不确定性度量被证明有效(MOPO、MOREL),但在复杂环境下模型精确优化困难。
本文提出一种悲观自举的offline RL算法PBRL,基于数据不确定性度量的model-free方法。PBRL通过Bootstrapping Q functions进行价值迭代估计,并将其估计的标准差来量化不确定性,然后将不确定性量化作为惩罚项用于价值函数以及策略优化。除此外,提出一种OOD数据采样技术,作为学习到的Q函数的正则化器。
2. Method
2.1 UNCERTAINTY QUANTIFICATION WITH BOOTSTRAPPING
维持K个bootstrap Q函数用于不确定性估计。其中,第k个Q原始更新目标为:
T
^
Q
θ
k
(
s
,
a
)
:
=
r
(
s
,
a
)
+
γ
E
^
s
′
∼
P
(
⋅
∣
s
,
a
)
,
a
′
∼
π
(
⋅
∣
s
)
[
Q
θ
−
k
(
s
′
,
a
′
)
]
\widehat{\mathcal{T}}Q_\theta^k(s,a):=r(s,a)+\gamma\widehat{\mathbb{E}}_{s'\sim P(\cdot|s,a),a'\sim\pi(\cdot|s)}\Big[Q_{\theta^-}^k(s',a')\Big]
T
Qθk(s,a):=r(s,a)+γE
s′∼P(⋅∣s,a),a′∼π(⋅∣s)[Qθ−k(s′,a′)]
通过K个Q函数的标准差进行不确定性估计
U
(
s
,
a
)
:
=
S
t
d
(
Q
k
(
s
,
a
)
)
=
1
K
∑
k
=
1
K
(
Q
k
(
s
,
a
)
−
Q
ˉ
(
s
,
a
)
)
2
.
\mathcal{U}(s,a):=\mathrm{Std}(Q^k(s,a))=\sqrt{\frac{1}{K}\sum_{k=1}^K\left(Q^k(s,a)-\bar{Q}(s,a)\right)^2}.
U(s,a):=Std(Qk(s,a))=K1k=1∑K(Qk(s,a)−Qˉ(s,a))2.
2.2 PESSIMISTIC LEARNING
对于在离线数据集
D
i
n
D_{in}
Din中数据,将不确定性度量作为惩罚项加入到Q函数的更新中
T
^
i
n
Q
θ
k
(
s
,
a
)
:
=
r
(
s
,
a
)
+
γ
E
^
s
′
∼
P
(
⋅
∣
s
,
a
)
,
a
′
∼
π
(
⋅
∣
s
)
[
Q
θ
−
k
(
s
′
,
a
′
)
−
β
i
n
U
θ
−
(
s
′
,
a
′
)
]
\widehat{\mathcal{T}}^{\mathrm{in}}Q_{\theta}^{k}(s,a):=r(s,a)+\gamma\widehat{\mathbb{E}}_{s^{\prime}\sim P(\cdot|s,a),a^{\prime}\sim\pi(\cdot|s)}\Big[Q_{\theta^{-}}^{k}(s^{\prime},a^{\prime})-\beta_{\mathrm{in}}\mathcal{U}_{\theta^{-}}(s^{\prime},a^{\prime})\Big]
T
inQθk(s,a):=r(s,a)+γE
s′∼P(⋅∣s,a),a′∼π(⋅∣s)[Qθ−k(s′,a′)−βinUθ−(s′,a′)]
而对于OOD的数据,PBRL首先从
D
i
n
D_{in}
Din采样OOD states,然后由当前策略
π
(
⋅
∣
s
O
O
D
)
\pi(\cdot | s^{OOD})
π(⋅∣sOOD)得到OOD action, 这部分数据的Q更新如下:
T
^
o
o
d
Q
θ
k
(
s
o
o
d
,
a
o
o
d
)
:
=
Q
θ
k
(
s
o
o
d
,
a
o
o
d
)
−
β
o
o
d
U
θ
(
s
o
o
d
,
a
o
o
d
)
\hat{\mathcal{T}}^{\mathrm{ood}}Q_\theta^k(s^{\mathrm{ood}},a^{\mathrm{ood}}):=Q_\theta^k(s^{\mathrm{ood}},a^{\mathrm{ood}})-\beta_{\mathrm{ood}}\mathcal{U}_\theta(s^{\mathrm{ood}},a^{\mathrm{ood}})
T^oodQθk(sood,aood):=Qθk(sood,aood)−βoodUθ(sood,aood)
算法实现中,引入一个额外的截断稳定早期训练过程:
max
{
0
,
T
o
o
d
Q
θ
k
(
s
o
o
d
,
a
o
o
d
)
}
.
\max\{0,\mathcal{T}^{\mathrm{ood}}Q_{\theta}^{k}(s^{\mathrm{ood}},a^{\mathrm{ood}})\}.
max{0,ToodQθk(sood,aood)}.
β
\beta
β是重要超参,在初始阶段的不确定估量不准确,因此采用较大值对Q函数保守估计,而随着训练的进行,不确定性估量逐渐稳定准确,
β
\beta
β减小
综上两种迭代方法,对Critic的更新函数如下:
L
c
r
i
t
i
c
=
E
^
(
s
,
a
,
r
,
s
′
)
∼
D
m
[
(
T
^
i
n
Q
k
−
Q
k
)
2
]
+
E
^
s
o
o
d
∼
D
i
n
,
a
o
o
d
∼
π
[
(
T
^
o
o
d
Q
k
−
Q
k
)
2
]
,
\mathcal{L}_{\mathrm{critic}}=\widehat{\mathbb{E}}_{(s,a,r,s^{\prime})\sim\mathcal{D}_{\mathrm{m}}}\big[(\widehat{\mathcal{T}}^{\mathrm{in}}Q^{k}-Q^{k})^{2}\big]+\widehat{\mathbb{E}}_{s^{\mathrm{ood}}\sim\mathcal{D}_{\mathrm{in}},a^{\mathrm{ood}}\sim\pi}\big[(\widehat{\mathcal{T}}^{\mathrm{ood}}Q^{k}-Q^{k})^{2}\big],
Lcritic=E
(s,a,r,s′)∼Dm[(T
inQk−Qk)2]+E
sood∼Din,aood∼π[(T
oodQk−Qk)2],
对于policy的更新目标为
π
φ
:
=
max
φ
E
^
s
∼
D
i
n
,
a
∼
π
(
⋅
∣
s
)
[
min
k
=
1
,
…
,
K
Q
k
(
s
,
a
)
]
\pi_\varphi:=\max_\varphi\widehat{\mathbb{E}}_{s\sim\mathcal{D}_{\mathrm{in}},a\sim\pi(\cdot|s)}\Big[\min_{k=1,\ldots,K}Q^k(s,a)\Big]
πφ:=φmaxE
s∼Din,a∼π(⋅∣s)[k=1,…,KminQk(s,a)]