STORM:为强化学习基于高效随机Transformer的世界模型

23年10月来自北理工和清华的论文“STORM: Efficient Stochastic Transformer based World Models for Reinforcement Learning”。

最近,基于模型的强化学习算法,在视觉输入环境中表现出了显著的效果。这些方法首先通过自监督学习构建真实环境的参数化模拟世界模型。通过利用世界模型的想象力,智体的策略得到了增强,而不受从真实环境中采样的限制。这些算法的性能在很大程度上依赖于世界模型的序列建模和生成能力。然而,构建一个复杂未知环境的完美准确模型几乎是不可能的。模型和现实之间的差异可能会导致智体追求虚拟目标,而在真实环境中的表现不佳。在基于模型的强化学习中引入随机噪声已被证明是有益的。在这项工作中,引入了基于随机 Transformer 的世界模型 (STORM),这是一种高效的世界模型架构,它将 Transformer 强大的序列建模和生成能力与变分自动编码器的随机性相结合。 STORM 在 Atari 100k 基准测试中达到人类平均水平的 126.7%,创下不使用前瞻搜索技术的先进方法新纪录。此外,在单个 NVIDIA GeForce RTX 3090 显卡上训练具有 1.85 小时实时交互经验的智体仅需 4.3 小时,与之前的方法相比,效率有所提高。

深度强化学习 (DRL) 已在不同领域取得了显著成功。然而,实现这种成功需要与环境进行大量交互,这阻碍了它在现实环境中的广泛应用。当处理更广泛的现实环境(例如缺乏可调速度模拟工具的无人驾驶和制造系统 [1, 2])时,这种限制变得尤为具有挑战性。因此,提高样本效率已成为 DRL 算法面临的一个关键挑战。

流行的 DRL 方法(包括 Rainbow [3] 和 PPO [4])由于两个主要原因而存在样本效率低下的问题。首先,价值函数的估计是一项具有挑战性的任务。这涉及使用深度神经网络 (DNN) 近似价值函数,并使用 n 步引导时间差(TD)对其进行更新,这自然需要多次迭代才能收敛 [5]。其次,在奖励稀疏的场景中,许多样本在价值函数方面表现出相似性,为 DNN 的训练和泛化提供有限的有用信息 [6, 7]。这进一步加剧了提高 DRL 算法样本效率的挑战。

为了应对这些挑战,基于模型的 DRL 算法应运而生,成为一种有前途的方法,它可以同时解决这两个问题,同时在样本效率高的环境中表现出显著的性能提升。这些算法首先通过自监督学习构建真实环境的参数化模拟世界模型。自监督学习可以通过多种方式实现,例如使用解码器重建原始输入状态 [8–10]、预测帧之间的动作 [7] 或采用对比学习来捕捉输入状态的内部一致性 [6, 7]。这些方法比传统的无模型 RL 损失提供了更多的监督信息,增强了 DNN 的特征提取能力。随后,通过利用使用世界模型生成的经验来改进智体的策略,消除采样约束,并与无模型算法相比更快地更新价值函数。

然而,使用世界模型进行想象的过程,涉及一个自回归过程,该过程会随着时间的推移累积预测误差。当想象轨迹与真实轨迹出现差异时,智体可能会无意中追求虚拟目标,导致在真实环境中的表现不佳。为了缓解这个问题,在世界模型中引入随机噪声已被证明是有益的 [9–11, 14]。变分自编码器能够自动学习高维数据的低维潜表示,同时将合理的随机噪声纳入潜空间,为图像编码提供了理想的选择。

人们已经进行了大量努力来构建高效的世界模型。例如,SimPLe [11] 利用 LSTM [15],而 DreamerV3 [10] 采用 GRU [16] 作为序列模型。LSTM 和 GRU 都是循环神经网络 (RNN) 的变型,在序列建模任务方面表现出色。然而,RNN 的循环性质阻碍并行计算,导致训练速度变慢 [17]。相比之下,Transformer 架构 [17] 最近在各种序列建模和生成任务中都表现出优于 RNN 的性能。它克服了忘记长期依赖关系的挑战,专为高效的并行计算而设计。虽然已经进行了多次尝试将 Transformer 纳入世界模型 [12、13、18],但这些工作并未充分利用该架构的功能。此外,这些方法需要更长的训练时间,并且无法超越基于 GRU 的 DreamerV3 的性能。

基于随机 Transformer 的世界模型 (STORM),是一种基于模型的 RL 的高效结构。世界模型最近的方法和STORM的比较:
SimPLe [11] 和 Dreamer [10] 依赖于基于 RNN 的模型,而 STORM 采用类似 GPT 的 Transformer [30] 作为序列模型。
与使用多个 token 的 IRIS [13] 相比,STORM 使用单个随机潜变量来表示图像。
STORM 遵循一个 vanilla Transformer[17] 结构,而 TWM[12] 采用 Transformer-XL [21] 结构。

在 STORM 的序列模型中,观察和动作融合为单个 token,而 TWM [12] 将观察、动作和奖励视为三个同等重要的独立 token。
与包含隐藏状态的 Dreamer [10] 和 TransDreamer [18] 不同,STORM 无需利用此信息即可重建原始图像。

如图所示 STORM 和其他方法在 Atari 100k 上的方法比较:SimPLe [11] 和 DreamerV3 [10] 使用 RNN 作为其世界模型,而 TWM [12]、IRIS [13] 和 STORM 使用 Transformer;单个 NVIDIA V100 GPU 上的每秒训练帧数 (FPS) 结果是从 SimPLe、TWM 和 IRIS 的其他显卡推断出来的,而 DreamerV3 和 STORM 则是直接评估的。

请添加图片描述

下表STORM与其他方法的比较:

请添加图片描述

该方法遵循基于模型强化学习算法的既定框架,该算法专注于通过想象力增强智体的策略 [5、9-11、13]。迭代以下步骤,直到达到规定的真实环境交互次数:
S1) 执行当前策略的几个步骤来收集真实环境数据,并将其加到重放缓冲区。
S2) 用从重放缓冲区采样的轨迹更新世界模型。
S3) 用由世界模型生成的想象经验改进策略,其中想象过程的起点从重放缓冲区采样。

在每个时间 t,数据点包含一个观察 ot、一个动作 at、一个奖励 rt 和一个延续标志 ct(一个布尔变量,指示当前情节是否正在进行)。重放缓冲区维护先进先出的队列结构,从而能够从缓冲区中采样连续的轨迹。

该世界模型的完整结构如图所示。实验中,专注于 Atari 游戏 [31],它会生成环境的图像观测 ot。直接在原始图像上对环境动态进行建模计算成本高昂且容易出错 [7–11, 13, 23]。

请添加图片描述

为了解决这个问题,利用 VAE [32] 将 ot 转换为潜随机分类分布 Zt。与之前的研究 [9, 10, 12] 一致,将 Zt 设置为包含 32 个类别的随机分布,每个类别有 32 个级。编码器 (qφ) 和解码器 (pφ) 结构实现为卷积神经网络 (CNN) [33]。随后,从 Zt 中采样一个潜变量 zt 来表示原始观测 ot。由于从分布中采样缺少用于后向传播的梯度,应用直通梯度技巧 [9, 34] 来保留它们。

请添加图片描述

在进入序列模型之前,用多层感知器 (MLP) 和连接(concatenation)将潜样本 zt 和动作 at 组合成单个 token et。此操作表示为 mφ,为序列模型准备输入。序列模型 fφ 将 et 的序列作为输入并产生隐藏状态 ht。对序列模型采用类似 GPT Transformer 结构 [30],其中自注意模块被后续掩码所掩盖,从而允许 et 关注序列 e1, e2, …, et。通过利用 MLP gφD、gφR 和 gφC,依靠 ht 来预测当前奖励 rtˆ、延续标志 ctˆ 和下一个分布 Zˆt+1。这部分世界模型的公式如下

请添加图片描述

智体的学习完全基于世界模型促进的想象过程,如图所示。为了启动想象过程,从重放缓冲区中随机选择一个简短的上下文轨迹,并计算初始后验分布 Zt。在推理过程中,不是直接从后验分布 Zt 中采样,而是从先验分布 Zt 中采样 zt。为了加速推理,在 Transformer 结构中采用 KV 缓存技术 [35]。

智体的状态由 zt 和 ht 连接而成,如下所示:

请添加图片描述

采用 DreamerV3 [10] 中的 Actor 学习设置 (AC学习算法)。

Atari 100k 包含 26 种不同的视频游戏,离散动作维度高达 18。100k 样本约束对应于 400k 个实际游戏帧,其中考虑跳帧(跳过 4 帧)和这些帧内的重复动作。此约束对应于大约 1.85 小时的实时游戏时间。智体的人类玩家归一化分数 τ = (A−R)/ (H−R)是根据智体获得的分数 A、随机策略获得的分数 R 以及人类玩家在特定环境中获得的平均分数 H 计算得出的。为了确定人类玩家的表现 H,玩家可以在相同的样本约束下熟悉游戏。

为了证明提出的世界模型结构的效率,将其与共享类似训练流程的基于模型 DRL 算法进行了比较。但是,与 [10、12、13] 类似,不会直接将结果与 MuZero [23] 和 EfficientZero [7] 等前瞻搜索方法进行比较,因为主要目标是改进世界模型本身。尽管如此,未来可以将前瞻搜索技术与该方法相结合,以进一步提高智体的性能。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值