这篇论文前两个月刚刚放出,研究了如何让人工智能(AI)更好地解决复杂的规划问题,比如在迷宫中寻找最短路径,或者推箱子游戏(Sokoban)中把箱子全部推到指定位置。
传统上,这类问题通常使用专门的规划算法来解决,比如A*搜索算法。但是,训练AI模型(如Transformer)来解决这些问题仍然很有挑战性。
这篇论文提出了一种叫做"搜索动态自举"(Search Dynamics Bootstrapping)的方法来训练Transformer模型:
-
首先,它训练一个Transformer模型来模仿A算法的搜索过程。也就是说,模型不仅学习预测最终的解决方案,还学习预测A算法搜索问题空间的中间步骤。这样,模型就学会了像A*算法一样"思考"。
-
接下来,通过微调这个训练好的模型,让它生成更短的搜索过程,同时仍然能找到最优解。这一步骤使模型变得更加高效。
通过这种方法,论文训练出一个叫做Searchformer的Transformer模型。实验表明,Searchformer在解决迷宫和Sokoban问题上,性能优于原始的A*算法。它能以更少的搜索步骤找到最优解。
这项工作的意义在于:
- 展示了如何让Transformer模型学会解决复杂的规划问题。
- 提出了一种新的训练方法(搜索动态自举),可以提高模型的效率。
- 为探索AI模型的推理和规划能力开辟了新的方向。
总之,这篇论文为训练AI模型解决规划问题提供了新的思路和方法,有助于开发出更加智能高效的AI系统,下面详细解读。
通过搜索动态自举实现超越A*的更好规划:使用Transformers进行规划
Lucas Lehnert, Sainbayar Sukhbaatar, DiJia Su, Qinqing Zheng, Paul Mcvay, Michael Rabbat, Yuandong Tian
Facebook人工智能研究院
摘要
尽管Transformers在各种应用场景中取得了巨大进展,但在解决复杂决策任务方面,此类架构仍然落后于传统的符号规划器。在这项工作中,我们展示了如何训练Transformers来解决复杂的规划任务。这是通过训练一个编码器-解码器Transformer模型来预测A搜索算法的搜索动态来实现的。我们微调这个模型,以获得一个Searchformer,一个Transformer模型,它能以93.7%的概率最优地解决之前未见过的Sokoban难题,同时使用的搜索步骤比最初用于训练的A实现少26.8%。在我们的训练方法中,A*的搜索动态被表示为一个标记序列,概述了在符号规划期间何时将任务状态添加到搜索树和从搜索树中移除。Searchformer以5-10倍更小的模型规模和10倍更小的训练数据集明显优于直接预测最优计划的基线。最后,我们展示了Searchformer如何扩展到更大、更复杂的决策任务,提高了解决任务的百分比并缩短了搜索动态。
1. 引言
基于Transformer的架构在不同任务中表现出令人印象深刻的性能,包括人类水平的对话、高质量的图像理解和视频生成、多模态生成以及代码补全。通过在互联网规模的数据集上训练这些架构,得到的模型(如大语言模型)可以在现实世界的应用中很好地泛化。
尽管取得了这些成功,但基于Transformer的架构和LLMs在解决规划和推理任务方面仍然存在困难。先前的研究表明,LLMs在多步规划任务或执行高阶推理时表现不佳。
近年来,已经提出了各种方法来提高Transformers在这些设置中的性能。一种方法是模拟人类的思维过程,并在输出响应之前产生中间"思考"。Chain-of-Thought(CoT)提示和Tree-of-thoughts(ToT)方法鼓励模型一步一步地"思考"。虽然这些技术通常是有效的,但它们也可能导致更差的性能,例如由于自我强化。此外,在一个数据集上有效的技术可能不能很好地应用于其他数据集,因为所涉及的推理类型发生了变化(例如,空间推理与数学推理)。如何使Transformers和LLMs能够规划、解决多步决策任务以及执行推理仍然难以捉摸,是一个活跃的研究领域。
我们的工作
我们展示了如何训练Transformers来稳健地解决复杂的规划任务。与LLMs类似,我们训练Transformers根据一系列单词预测下一个单词。我们的实验使用具有合成语言和词汇表的合成生成数据集。使用这个框架,我们演示了如何构建训练数据,使得生成的模型模仿A搜索执行的计算。最后,我们提出了Searchformer,一个Transformer模型,它以比我们的A参考实现更少的搜索步骤解决复杂的规划任务。该模型是通过搜索动态自举获得的,这是一种首先训练Transformer模仿A*的搜索过程,然后微调模型以在更少的搜索步骤内找到最优计划的方法。
为了训练Transformer执行规划,我们将规划任务及其最优解计划表示为一系列单词,称为标记。我们还将A执行的计算记录到由标记序列组成的执行跟踪中,从而生成一个捕获A搜索动态的序列数据集。使用这些搜索增强的序列,训练一个Transformer模型生成编码A*搜索动态以及最优计划的标记序列。
随后,将得到的搜索增强模型进行微调,以生成更短的标记序列,同时仍然输出最优计划。我们将这个最终微调的模型称为Searchformer。在解决Sokoban难题时,我们的模型在93.7%的所有测试任务中找到最优解,同时执行的搜索步骤比最初用于训练的A*实现平均少26.8%。
通过一系列控制任务复杂性、数据集大小和模型大小的实验,我们证明了将执行跟踪包含到训练数据中可以提高独立测试任务集的性能——尽管生成的序列长度增加了10到100倍。我们发现,搜索增强模型(将执行跟踪包含到其训练数据中)在未见过的任务上生成最优计划的频率比更大的解决方案模型(仅在包含任务描述和任务解决方案的序列上训练)高10倍,而训练序列数量要少10倍。这一结果突显了将A*的搜索动态纳入Transformer模型训练过程的威力。
2. 相关工作
虽然现有工作利用合成数据集来学习推理策略,但我们的研究在这方面有根本的不同。我们专注于提高嵌入在Transformer权重中的推理能力。现有的算法(如AlphaZero、MuZero和AlphaGeometry)使用现有符号规划算法的输出(其内部状态未被使用,即被视为黑盒)来优化神经网络。例如,Silver等人使用MCTS作为策略改进算子来更新神经网络的权重。相比之下,我们提出的搜索动态自举方法使用Transformer模型泛化到更高效的搜索模式,并改进模型本身。规划算法(连同其内部搜索动态)被用于最初训练Transformer模型。
之前的工作侧重于在推理任务的执行跟踪上训练神经网络,或者训练神经网络预测最优动作。相比之下,我们侧重于训练Transformer生成A*计算最优计划时的整个搜索过程。我们的模型不仅预测单个动作,还预测解决任务的整个多步计划。
我们的工作与神经符号系统有一些相似之处,它们构建可微分的架构来模拟现有符号系统的功能。然而,这些方法使用专用组件(例如,显式内存组件,内置递归),而Searchformer侧重于下一个标记的预测。在这里,Searchformer依赖于生成长上下文和位置嵌入来预测最优计划。最终,我们的工作阐明了如何构建自动学习规划机制的更通用的架构。
在强化学习(RL)设置中,先前的工作研究了使用Transformer架构来解决复杂的顺序决策任务。然而,这项先前工作提出了对试错交互轨迹建模的不同方法,并侧重于预测下一个动作、状态或奖励或它们的组合。相比之下,我们演示了如何使用Transformer来模拟计算最优多步计划所涉及的搜索步骤。
MCTSNet也试图学习搜索过程本身,但仍然将MCTS搜索过程硬编码到神经网络中,这导致二次反向传播开销,只能处理最多500步的展开,而我们的方法可以处理更长的搜索执行跟踪。我们证明了Transformers不仅可以模仿符号规划算法,还可以通过微调来发现更有效的启发式方法。
3. 问题设置
图1概述了我们的合成数据集生成过程。我们考虑两个领域:迷宫导航(图1(a))和求解Sokoban难题(图5)。在迷宫导航中,目标是找到通过n×n迷宫的最短路径。在Sokoban中,工人可以向上、向下、向左或向右移动,必须将每个箱子推到码头上才能解决难题。错误的移动可能会立即导致死胡同,因此需要跨多个时间步骤进行推理才能解决难题。难题中的每个状态都由箱子和工人位置的组合构成,这使得Sokoban在计算上比迷宫导航更难解决。
3.1 生成A*搜索的执行跟踪
A*算法通过操纵两组节点来计算最优计划:
- 一个前沿集合,包含当前的搜索前沿。
- 一个封闭集合,包含所有搜索过的节点。
在图1(a)的迷宫示例中,每个节点对应一个空的(非墙)网格单元。对于每个节点,该算法计算启发式值和从起点开始的代价。在任何给定的迭代中,接下来搜索哪个节点取决于前沿集合和封闭集合的内容以及启发式值和从起点开始的代价值(图1©,左面板)。A的执行跟踪是通过跟踪插入前沿集合和封闭集合的所有操作以及启发式和从起点开始的代价值来收集的(图1©,右面板)。图1©的右面板说明了图1(b)所示迷宫示例的结果跟踪。每一行对应于将节点插入前沿集合(由create标记表示)或将节点移动到封闭集合(由close标记表示)。每个节点由其在迷宫中的(x,y)位置以及两个代价标记表示。然后将生成的计划附加到此跟踪中。构建此跟踪,使得给定任何前缀都可以正确预测下一个标记。对于迷宫数据集,A使用到目标位置的曼哈顿距离作为启发式。在Sokoban中,A*首先将每个箱子与最近的码头匹配,然后计算每对箱子和码头之间所有曼哈顿距离的总和。
对于每个实验,我们生成两种不同的令牌序列变体,如图1所示:
- 解决方案序列,格式为,其中部分对任务描述进行编码,部分对最优计划进行编码(图1(b))。
- 搜索增强序列,格式为,其中部分对A*的执行跟踪进行编码(图1©)。
因为每个模型都是从头开始训练的,所以生成的模型被专门训练来只预测概述一组不同规划任务的最优计划的序列。训练后,如果模型的输出包含最优或可行的解决方案计划,则对其进行解析和评估。
3.2 训练Transformer模型
在生成令牌序列数据集时,每个任务都是唯一的,并且构建测试集,使其不包含训练集的任何重复。通过这个实验设计,我们希望了解如何使用Transformers来解决规划任务并泛化到以前未见过的测试任务。
通过包含中间计算步骤,Transformer模型被训练以有效地模仿A*算法执行的计算。与Procedure Cloning(其中学习一个神经网络来预测最优状态/动作序列,在我们的例子中是任务提示和最优计划)不同,我们的Transformer模型还学习预测导致最优计划的整个思考过程,包括尝试但失败的路径。
对于每个实验,我们使用集成了Rotary Position Embeddings (RoPE)的编码器-解码器T5架构的改编版本。更多细节和超参数可以在附录B中找到。编码器处理训练序列的部分,解码器处理格式的序列(搜索增强模型)或仅格式的序列(仅解决方案模型)。根据模型变体,每个网络被训练以最大化解码器生成的分布与从训练数据集中采样相应序列的分布之间的交叉熵。附录A更详细地描述了我们的优化设置。
3.3 通过搜索动态自举超越算法模仿
为了减少搜索增强模型在推理期间生成的标记数量,我们实现了一种方法来改变解码器生成执行跟踪的分布。首先,训练一个搜索增强模型以模仿A搜索的搜索动态。为了使用这个搜索增强模型发现新的搜索动态并探索执行跟踪空间,搜索增强模型必须为同一任务提示生成不同的序列。我们通过在训练数据中引入非确定性并使用非确定性A实现来实现这一点,该实现随机打破代价平局并随机化扩展子节点的顺序。这种方法不会降低A搜索本身的效率,仅仅改变了搜索不同节点的顺序,同时仍然遵循A的启发式和代价计算。生成的搜索增强模型将近似生成训练序列的概率分布。
一旦模型被训练以模仿非确定性A*搜索的搜索动态,它就被用来生成一个由更短的标记序列组成的新训练数据集。通过使用训练好的搜索增强模型对每个训练提示采样多个不同的标记序列来构建这个新数据集。在这一步中,我们只使用训练数据集进行自举,而不使用测试数据集。解析每个生成的序列并检查它是否以最优计划结束。如果是这种情况,并且序列也比原始训练数据集中包含的相应序列更短,则将这个缩短的序列包含在新的短序列训练数据集中。如果生成的序列没有以最优计划结束或者比原始训练序列更长,则重新使用原始训练数据集中的序列。
随后,在新的短序列训练数据集上微调搜索增强模型。为了与模仿A搜索动态的搜索增强模型区分开来,我们将这个新模型称为Searchformer。然后可以通过使用得到的微调模型生成下一个更短的序列数据集,然后再次微调Searchformer模型来重复此过程。在第4.3节中,我们证明了这个过程确实减少了推理期间执行的步骤数量,同时进一步提高了性能。Searchformer模型不再模仿A搜索,而是发现了一种使用更少步骤解决规划问题的新方法。
4. 实验
在我们的实验中,我们使用两种不同的A*实现来生成序列数据:
-
确定性A数据集:通过以确定性方式执行A(通过确定性地排序子节点并打破相等代价平局)生成序列。因此,给定一个任务提示,最优计划和A执行跟踪是唯一的。在这里,Transformer隐式地学习数据中编码的确定性打破规则。评估这样一个模型很简单,因为生成的序列需要与A生成的序列完全匹配。
-
非确定性A数据集:通过以非确定性方式执行A(通过随机排序子节点并随机打破相等代价平局)生成序列。因此,给定一个任务提示,最优计划和A执行跟踪不再是唯一的,并且有多个正确的响应。在这里,Transformer学习生成隐式编码在序列数据中的随机平局打破规则。因此,不同执行之间生成的序列有所不同,但生成的计划仍然是最优的,执行跟踪仍然遵循第3.3节中描述的A的代价和启发式计算。
图7概述了每个数据集的标记序列长度,并显示生成的A*执行跟踪的长度随着任务复杂性的增加而增长。图8显示训练集和测试集在难度上相匹配,并且具有可比的跟踪长度。对于每个任务,一个模型可能生成一个以最优计划、可行计划(正确但次优的计划)或无效计划结尾的搜索序列。在附录D中,我们概述了如何评分每个模型预测可行计划和最优计划的能力,以及如何评估搜索增强模型和Searchformer模型的搜索动态的细节。
除非另有说明,每个实验重复五次,每个图绘制所有重复的平均值。所有报告的误差表示测量的标准误差(SEM)。
4.1 迷宫导航
在第一组实验中,我们训练一组编码器-解码器Transformer模型来预测迷宫导航任务的最优计划。我们在不同的训练运行之间改变训练数据集大小和模型大小(优化参数的数量),并在使用相同超参数生成的测试任务上评估每个模型。
确定性A*
图2(a)绘制了对于测试任务生成正确响应的百分比。解决方案模型和搜索增强模型都在确定性A数据集上训练,并评估它们是否完全复现A搜索生成的标记序列(请参考附录D中的精确匹配标准)。可以观察到,解决方案模型的性能远不如大多数搜索增强模型。只有对于足够大的训练数据集,解决方案模型才能匹配最差的搜索增强模型的性能。在低训练数据机制下(100,000个训练序列及以下),解决方案模型的性能显著下降,而每个搜索增强模型的性能保持相对较高。
这个结果令人惊讶,因为对于90%以上的测试迷宫,搜索增强模型生成数千个标记长的格式序列,而没有预测任何单个标记不正确。而解决方案模型平均预测的序列短10倍,却明显不如搜索增强模型。即使是最小的搜索增强模型也明显优于参数更多的解决方案模型。
这个结果突出了训练Transformers生成长算法执行跟踪的威力。我们没有观察到通常限制基于深度模型的RL agent的复合预测错误,因为使用的反向因果解码器网络为n步序列构造n×n的注意力图。在这里,Transformer架构的这个特性被用来在预测最优计划时提高性能。
非确定性A*
当在非确定性A*数据上训练时,模型可以为一个任务输出多个不同的最优路径。在这里,我们使用每个模型为每个任务生成64个标记序列。如果64个序列中的任何一个包含最优计划,则测试任务被算作正确回答(请参考附录D中的任何最优64标准)。因为我们只测试至少一个生成的序列是否包含最优计划,所以我们在图2(b)中获得了比图2(a)更高的绝对数字。
图2(b)绘制了生成64个标记序列时找到最优计划的测试任务的百分比。在这里,我们可以观察到与确定性A*数据集类似的模式:即使是最小的搜索增强模型也优于解决方案模型,特别是对于小型训练集。此外,我们发现模型大小只在使用非常小的训练数据集(50,000个训练序列)时影响每个搜索增强模型的性能。对于更大的训练数据集大小,没有发现显著差异。增加解决方案模型的参数数量并不能显著提高它们在低数据机制下的性能(图9)。
不同任务难度级别下的性能
最后,图2©说明了任务的难度如何影响每个模型的性能。在这里,我们再次关注非确定性A*生成的数据集,并考虑迷宫大小的函数正确解决的测试任务数量。迷宫越大,任务的状态空间就越大,找到最优解决方案计划所需的计算就越多。虽然解决方案模型的性能随着任务变得更具挑战性而迅速下降,但搜索增强模型保持相对较高的准确性,即使对于其最小的模型大小也是如此。附录F给出了所有迷宫大小的完整比较。
总的来说,虽然解决方案模型学会预测最优计划(如果使用的训练数据集足够大且多样化),但搜索增强模型在低数据机制下表现明显更好,并且更好地扩展到更困难的任务。搜索增强模型之所以达到更高的性能,是因为它们可以在推理期间进行按需计算。更具体地说,搜索增强模型模仿了导致最优计划的基于接地推理链的搜索动态,而解决方案模型必须通过监督学习来推断任务描述和最优计划之间的直接相关性,其中许多这样的相关性在测试任务集上评估期间可能是虚假的和不可靠的。
4.2 求解Sokoban难题
为了测试是否可以在具有不同令牌化模式和不同转换结构的不同和更复杂的任务上获得类似的结果,我们使用非确定性A*实现重复了Sokoban难题的实验。表1列出了每个模型为每个测试任务生成正确最优计划的频率。与之前一样,通过训练执行跟踪,搜索增强模型优于解决方案模型。即使将解决方案模型的参数化增加到7.47亿个参数,也只能带来微小的性能改进。平均而言,这个7.47亿参数的解决方案模型仍然略逊于更小的1.75亿参数搜索增强模型。这个实验进一步证实了我们在具有不同转换结构和不同令牌化方法的更复杂规划任务上的发现。
4.3 Searchformer:通过自举改进搜索动态
在最后一个实验中,我们研究了如何迭代地改进搜索增强模型,以在生成更短的执行跟踪的同时计算最优计划。在这里,我们的目标是在仍然产生最优解的情况下缩短搜索跟踪的长度。
我们从在非确定性A* Sokoban数据集上训练的最小搜索增强模型开始,并使用它来生成一个新的更短的序列训练数据集,如第3.3节所述。对于训练数据中的每个Sokoban难题,我们通过从Transformer的输出分布中采样标记生成了32个不同的格式序列,如果它包含最优计划,则包括最短的生成(以标记衡量)。随后,我们在这个新创建的训练数据上微调搜索增强模型(通过运行额外的10,000个训练步骤)以获得第一个Searchformer模型。使用这个Searchformer模型,我们随后生成另一个短序列数据集,并重复微调过程以进一步改进模型。
图3(a)说明了Searchformer模型生成的序列长度如何通过我们的搜索动态自举方法迭代缩短。通过每次改进步骤,生成的跟踪的平均长度——搜索步骤的数量——减少(图3(a))。在计算最优计划时,最终的Searchformer模型生成的搜索动态序列平均比最初用于训练的A实现短26.8%。因此,Searchformer模型找到了一种比用于训练初始搜索增强模型的A实现更有效的方法,可以用更少的步骤找到复杂任务的计划。在图3(b)中,我们可以观察到,搜索增强模型生成的序列在长度上平均与A*搜索生成的序列相匹配。Searchformer模型生成更短的序列,导致分布偏向更短的序列长度。
正如表1所报告的,微调模型导致了显著的性能改进,分别将不正确和非最优解的比率降低了40%和30%。成功加权成本(SWC)分数考虑了正确解决的测试任务的数量以及预测计划与最优长度的接近程度(附录D)。在这里,完美分数是1,从表1可以看出,相对较小的Searchformer与最大的解决方案模型性能相当(也注意SEM值很小)。此外,搜索动态改进长度比(ILR)衡量每个执行跟踪的长度缩短了多少(附录D)。随着每次改进迭代,分数增加并攀升到1以上。例如,A*搜索动态比微调3步后Searchformer生成的序列长约34.3%。
图3和表1中报告的结果仅比较了每个模型在正确或最优解决的测试任务上的性能。为了测试模型是否仅在具有更短执行跟踪的较简单测试任务上过拟合,我们在图10中将A*生成的执行跟踪长度与每个模型生成的执行跟踪长度作为散点图绘制。该图中的每个点对应一个测试任务。在这里,通过搜索动态自举缩短执行跟踪的趋势很明显,也可以观察到任何模型都不仅专门解决具有更短执行跟踪的更简单的测试任务。
5. 讨论
先前的工作发现LLMs在解决复杂的决策任务方面存在困难。Searchformer证明,通过适当的训练数据,Transformers实际上可以解决复杂的规划任务。此外,Searchformer稳健地遵循符号规划器的中间步骤(执行跟踪),并在跟踪长度方面超越了最初训练它的人工制定的基于规则的规划策略。与直接预测解决方案的解决方案模型相比,我们的搜索增强模型需要更少的训练序列,并且更好地扩展到更复杂的规划任务。
5.1 局限性
目前,Searchformer是在A*的执行跟踪上训练的,以学习复杂的规划策略。然而,跟踪长度可能随最优计划的长度呈指数增长(见图7),在生成的令牌序列数据上训练可能在计算上变得非常昂贵。事实上,所呈现的实验使用的令牌序列比用于训练Llama 2等LLM的序列要长得多。
5.2 未来工作
缓解这一限制并提高所提方法效率的一种方法是使用课程学习:从具有相当长执行跟踪的简单任务开始,训练和微调Searchformer以缩短跟踪长度,然后将改进的模型适应更复杂的任务。另一种可能性是探索其他规划算法或将更好的启发式或价值函数集成到A*搜索中,类似于MCTS,以限制搜索算法探索的最大深度。集成分层规划方法和时间抽象是另一个途径。这将使生成的模型能够在多个时间步和状态上进行抽象,以使用更少的计算步骤找到最优计划。
与Plansformer相比,所提出的工作演示了如何从头开始训练Transformers在合成数据集上解决复杂的规划任务。我们相信,我们的结果和方法可以与Plansformer等方法相结合,以微调LLM并使其能够更稳健地解决复杂的规划任务。最终,我们希望我们的研究能够阐明如何将Transformers用于多步规划,并希望能够为进一步研究提供信息,以更好地理解LLM的推理能力。
6. 更广泛的影响
我们的工作侧重于符号规划任务,并使用合成数据集进行训练。虽然我们在本文中探索的任务可以用简单的符号求解器轻松解决,但研究神经网络在此类任务上的有效性很重要。在这里,我们提供了一个概念证明,说明如何使用基于Transformer的神经网络来稳健地解决复杂的规划任务。通过我们的工作,我们希望为进一步研究提供信息,以更好地理解大语言模型的推理能力。
总之,这篇论文提出了一种名为Searchformer的方法,通过训练Transformer模型模仿A搜索的搜索动态,然后通过自举逐步缩短搜索轨迹,最终得到一个性能优于A的规划器。实验表明,搜索增强模型在各种规划任务上都优于只预测解决方案的基线模型。这项工作为使用Transformer进行复杂规划提供了新的思路,并为探索语言模型的推理能力开辟了道路。未来的工作可以进一步优化训练效率,将方法应用到更多规划任务中,并与现有方法相结合,以增强大语言模型的规划和推理能力。
Q&A
在这篇论文中,作者使用了两种不同版本的A*算法来生成训练数据:
-
Deterministic A* (确定性A*):
- 在搜索过程中,确定性地对子节点进行排序,并以确定性的方式打破相同代价的节点之间的平局。
- 给定相同的问题,确定性A*总是会生成相同的搜索路径和最优解。
- 训练的模型需要准确地复现确定性A*的搜索过程。
-
Non-deterministic A* (非确定性A*):
- 在搜索过程中,随机地对子节点进行排序,并以随机的方式打破相同代价的节点之间的平局。
- 给定相同的问题,非确定性A*可能会生成不同的搜索路径,但最终的解都是最优的。
- 训练的模型需要学习非确定性A*所隐含的随机性,生成不同但都正确的搜索路径。
使用非确定性A*生成训练数据的主要目的是增加数据的多样性。通过引入随机性,模型可以学习到多种不同但都有效的搜索策略。这可以提高模型的泛化能力,使其在面对新问题时更加鲁棒。
在训练Searchformer模型时,作者首先在非确定性A的数据上训练模型,让其学会对同一个问题生成不同的搜索路径。然后,再通过自助法(Bootstrapping)微调模型,使其生成更短但仍然最优的搜索路径。这种训练方式使Searchformer能够找到比原始A算法更高效的搜索策略。
非确定性A和确定性A的区别在于搜索过程中是否引入随机性。使用非确定性A*生成训练数据可以提高模型的泛化能力和效率。
关键代码进行解读
https://github.com/facebookresearch
AStarTraceIterableDataset
类:
class AStarTraceIterableDataset(IterableDataset):
def __init__(
self,
name: str,
num_sequences: Optional[int] = None,
reasoning_range: Optional[Tuple[int, int]] = None,
shuffle: bool = False,
use_test: bool = False,
load_batch_size: int = 10000,
plan_only: bool = False,
):
...
- 这个类继承自
IterableDataset
,表示一个可迭代的数据集。 __init__
方法接受数据集名称、序列数量、推理范围、是否混洗、是否使用测试集、加载批次大小和是否仅使用计划等参数。- 在
__init__
方法中,根据参数从MongoDB加载相应的数据ID。
def __iter__(self) -> Iterator[AStarTrace]:
worker_info = get_worker_info()
ids_wk = self.ids
if worker_info is not None:
per_worker = math.ceil(len(ids_wk) / worker_info.num_workers)
worker_id = worker_info.id
it_start = worker_id * per_worker
ids_wk = ids_wk[it_start : it_start + per_worker]
if not self.use_test:
batch_loader = self.dataset.train_it(ids_wk, self.load_batch_size)
else:
batch_loader = self.dataset.test_it(ids_wk, self.load_batch_size)
for batch in batch_loader:
tensor_list = self.tokenizer.tokenize_batch(batch, self.plan_only)
if self.shuffle:
random.shuffle(tensor_list)
for trace in tensor_list:
yield trace
__iter__
方法返回一个AStarTrace
对象的迭代器。- 首先获取工作器信息,并根据工作器ID将数据ID分配给各个工作器。
- 根据
use_test
标志,从数据集中加载训练批次或测试批次。 - 对每个批次,使用分词器将其转换为张量列表。
- 如果
shuffle
为 True,则对张量列表进行混洗。 - 最后,逐个生成每个跟踪对象。
NextTokenPredictionLoss
类:
class NextTokenPredictionLoss(nn.Module):
def __init__(self, model: EncoderDecoder):
super().__init__()
self.model = model
self.loss_obj = nn.CrossEntropyLoss(reduction="none")
def forward(
self,
batch: BatchedAStarTrace,
) -> Tuple[Tensor, Dict[str, float]]:
logits = self.model(
prompt=batch.prompt,
prompt_mask=batch.prompt_mask,
trace=batch.trace_plan[:, :-1],
)
logits_1 = logits.reshape(-1, logits.shape[-1])
loss_1 = self.loss_obj(logits_1, batch.trace_plan[:, 1:].reshape(-1))
loss_mat = loss_1.reshape(*batch.trace_plan[:, 1:].shape)
tok_eq = (logits.argmax(-1) == batch.trace_plan[:, 1:]).float()
loss_plan = (loss_mat * batch.plan_mask[:, 1:]).sum(-1)
loss_plan /= batch.plan_mask[:, 1:].sum(-1)
tok_eq_plan = (tok_eq * batch.plan_mask[:, 1:]).sum(-1)
acc_plan = (tok_eq_plan == batch.plan_mask[:, 1:].sum(-1)).float()
trace_seq_len = batch.trace_mask[:, 1:].sum(-1)
if torch.any(trace_seq_len > 0):
loss_trace = (loss_mat * batch.trace_mask[:, 1:]).sum(-1)
loss_trace /= trace_seq_len
tok_eq_trace = (tok_eq * batch.trace_mask[:, 1:]).sum(-1)
acc_trace = tok_eq_trace == batch.trace_mask[:, 1:].sum(-1)
acc_trace = acc_trace.float()
else:
loss_trace = torch.zeros_like(loss_plan)
acc_trace = torch.zeros_like(acc_plan)
mask = batch.trace_mask[:, 1:] + batch.plan_mask[:, 1:]
assert mask.max() == 1.0
loss = (loss_mat * mask).sum(-1)
loss /= mask.sum(-1)
acc_objective = ((tok_eq * mask).sum(-1) == mask.sum(-1)).float()
loss_log = {
"loss.objective": loss.mean().item(),
"loss.trace": loss_trace.mean().item(),
"loss.plan": loss_plan.mean().item(),
"accuracy.trace": acc_trace.mean().item(),
"accuracy.plan": acc_plan.mean().item(),
"accuracy.objective": acc_objective.mean().item(),
}
return loss.mean(), loss_log
- 这个类继承自
nn.Module
,表示一个用于计算下一个标记预测损失的模块。 __init__
方法接受一个EncoderDecoder
模型作为参数,并初始化交叉熵损失函数。forward
方法接受一个BatchedAStarTrace
对象作为输入,并返回损失值和损失日志字典。- 首先,使用编码器-解码器模型计算给定提示和跟踪的对数概率(logits)。
- 然后,计算对数概率与目标跟踪计划之间的交叉熵损失。
- 接下来,分别计算跟踪部分和计划部分的损失和准确率。
- 如果跟踪序列长度大于0,则计算跟踪损失和准确率;否则,将它们设为0。
- 最后,计算总体损失和准确率,并返回平均损失和损失日志字典。
TrainRun
类:
class TrainRun:
def __init__(self, config: TrainConfig):
self.config = config
self.ckpt_data = CheckpointDataset()
self.rank = get_rank()
self.world_size = get_world_size()
vocab_size = self.test_dataset.tokenizer.vocab_size
model_config = EncoderDecoderConfig.from_name(
enc_name=self.config.encoder,
dec_name=self.config.decoder,
vocab_size=vocab_size,
)
torch.cuda.set_device(self.rank % torch.cuda.device_count())
self.model = model_config.construct_model().cuda()
self.loss = DDP(NextTokenPredictionLoss(self.model))
self.optimizer, self.schedule = build_optimizer(
self.model, self.config.optimizer
)
self.step = 0
def train_step(self, batch: Any) -> Dict[str, Any]:
self.optimizer.zero_grad(set_to_none=True)
batch_cuda = batch.cuda()
loss_obj, loss_dict = self.loss(batch_cuda)
loss_obj.backward()
self.optimizer.step()
self.schedule.step()
loss_obj = None
self.optimizer.zero_grad(set_to_none=True)
return loss_dict
def train(self, run_data: TrainRunData):
steps_to_go = self.config.optimizer.train_steps - self.step
if steps_to_go == 0:
logging.info("Run already complete. No further steps to train.")
return
logging.info("Starting training ...")
train_logger = TrainLogger(self.rank)
for batch in repeat_iterator(self.train_dataloader, steps_to_go):
step_result = self.train_step(batch)
train_logger.log(step_result, len(batch))
self.step += 1
if self.step % self.config.log_interval == 0:
log_dict = train_logger.get_log_dict_and_reset(self.step)
lr_list = self.schedule.get_last_lr()
lr_dict = {str(i): lr for i, lr in enumerate(lr_list)}
log_dict["value"]["lr"] = lr_dict
run_data.log_train(self.config.run_id, log_dict)
logging.info(
f"Completed {self.step} steps, "
+ f"lr={self.schedule.get_last_lr()}"
)
if self.step % self.config.eval_interval == 0:
self.evaluate(run_data)
self.checkpoint()
if self.step % self.config.eval_interval > 0:
self.evaluate(run_data)
self.checkpoint()
TrainRun
类表示一次完整的训练运行。__init__
方法接受一个TrainConfig
对象作为参数,初始化训练运行所需的组件,包括检查点数据集、模型、损失函数、优化器和学习率调度器。train_step
方法执行单个训练步骤,包括前向传播、反向传播和优化器更新,并返回损失字典。train
方法执行完整的训练循环,包括训练步骤、日志记录、评估和检查点保存。- 首先计算剩余的训练步骤数,如果为0则表示训练已完成。
- 然后,使用
repeat_iterator
函数遍历训练数据加载器,执行训练步骤。 - 每隔一定间隔(由
log_interval
控制),记录训练日志并将其保存到MongoDB。 - 每隔一定间隔(由
eval_interval
控制),执行评估并保存检查点。 - 如果最后一步不是评估间隔的倍数,则在训练结束后执行一次评估和检查点保存。
以上是对关键代码的逐行解读。这些代码实现了数据集加载、损失函数计算和完整的训练循环,体现了论文中描述的训练过程和关键组件。