Offline RL :Bootstrapped Transformer for Offline Reinforcement Learning

NIPS 2022
paper
code
可看作是一种数据增强

Intro

最近的一些工作通过将离线 RL 视为一种通用的序列生成问题,并采用诸如 Transformer 架构的序列模型来模拟轨迹上的分布。然而,一般离线 RL 任务中使用的训练数据集通常非常有限,并经常因分布覆盖不足而受到影响,这可能对训练序列生成模型不利。

为此,作者提出了 Bootstrapped Transformer 算法,该算法结合了自举(bootstrapping)的思想,利用学习到的模型生成更多的离线数据,以进一步增强序列模型的训练。通过在两个离线 RL 基准上的广泛实验,作者证明了他们的模型可以大幅弥补现有的离线 RL 训练限制,并超越其他强基线方法。

Method

首先根据序列数据添加累计回报 R t = ∑ t ′ = t T γ t ′ − t r t ′ R_{t}=\sum_{t^{\prime}=t}^{T}\gamma^{t^{\prime}-t}r_{t^{\prime}} Rt=t=tTγttrt, 并对原始数据处理成离散的token space
τ = τ dis = ( … , s t 1 , s t 2 , … , s t N , a t 1 , a t 2 , … , a t M , r t , R t , … ) . \tau =\tau_{\text{dis}}=\begin{pmatrix}\ldots,s_t^1,s_t^2,\ldots,s_t^N,a_t^1,a_t^2,\ldots,a_t^M,r_t,R_t,\ldots\end{pmatrix}. τ=τdis=(,st1,st2,,stN,at1,at2,,atM,rt,Rt,).
接下来采用TT架构,通过最大化似然函数 L \mathcal{L} L优化序列模型
log ⁡ P θ ( τ t ∣ τ < t ) = ∑ i = 1 N log ⁡ P θ ( s t i ∣ s t < i , τ < t ) + ∑ j = 1 M log ⁡ P θ ( a t j ∣ a t < j , s t , τ < t ) + log ⁡ P θ ( r t ∣ a t , s t , τ < t ) + log ⁡ P θ ( R t ∣ r t , a t , s t , τ < t ) L ( τ ) = ∑ t = 1 T log ⁡ P θ ( τ t ∣ τ < t ) , \begin{aligned} \log P_{\theta}(\tau_{t}|\tau_{<t})& =\sum_{i=1}^{N}\log P_{\theta}(s_{t}^{i}|s_{t}^{<i},\tau_{<t})+\sum_{j=1}^{M}\log P_{\theta}(a_{t}^{j}|a_{t}^{<j},s_{t},\tau_{<t}) \\ &+\log P_\theta(r_t|\boldsymbol{a}_t,s_t,\tau_{<t})+\log P_\theta(R_t|r_t,\boldsymbol{a}_t,s_t,\tau_{<t})\\ \mathcal{L}(\tau)&=\sum_{t=1}^T\log P_\theta(\tau_t|\tau_{<t}), \end{aligned} logPθ(τtτ<t)L(τ)=i=1NlogPθ(stist<i,τ<t)+j=1MlogPθ(atjat<j,st,τ<t)+logPθ(rtat,st,τ<t)+logPθ(Rtrt,at,st,τ<t)=t=1TlogPθ(τtτ<t),

Trajectory Generation

接下俩便是利用模型生成轨迹实现数据增强。文章这里提出了两种增强方法

  1. Autoregressive generation: y ~ n ∼ P θ ( y n ∣ y ~ < n , τ ≤ T − T ′ ) \tilde{y}_n\sim P_\theta\left(y_n|\tilde{y}_{<n},\tau_{\leq T-T'}\right) y~nPθ(yny~<n,τTT)。通过自回归的方法生成各个token。
    在这里插入图片描述

  2. Teacher-forcing generation: y ~ n ∼ P θ ( y n ∣ y < n , τ ≤ T − T ′ ) \tilde{y}_n\sim P_\theta\left(y_n|{y}_{<n},\tau_{\leq T-T'}\right) y~nPθ(yny<n,τTT),正常序列生成
    在这里插入图片描述
    增强后的tokens数据将于原始数据concatenate共同训练序列模型

为了防止由于训练数据不准确而导致的累积学习偏差,根据生成百分比 η% 选择每批中置信度分数最高的部分轨迹。置信度定义为所有生成的令牌的平均对数概率为:
c ( τ ) = 1 T ′ ( N + M + 2 ) ∑ t = T − T ′ + 1 T log ⁡ P θ ( τ t ∣ τ < t ) c(\tau)=\frac{1}{T'(N+M+2)}\sum_{t=T-T'+1}^T\log P_\theta(\tau_t|\tau_{<t}) c(τ)=T(N+M+2)1t=TT+1TlogPθ(τtτ<t)

伪代码

在这里插入图片描述
算法提出两种Bootstrapp形式。这两种方法都将首先在原始的离线轨迹上训练序列模型,然后利用它生成新的轨迹。但是对增强数据的利用方式不同

  1. Boot-o: 将再次在生成的轨迹上训练模型,并在使用它们后立即丢弃它们
  2. Boot-r:将生成的轨迹附加到原始数据集中,生成的轨迹将在附加到数据集之后的每个时期使用。

实验结果看第一种结合自回归的增强稍微好点

results

对比其他Offline的方法
在这里插入图片描述
对比两种增强方法
在这里插入图片描述

  • 10
    点赞
  • 25
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值