The Surprising Effectiveness of Test-Time Training for Abstract Reasoning
基本信息
博客贡献人
JokerLin
作者
Ekin Akyürek,Mehul Damani,Linlu Qiu,Han Guo,Yoon Kim,Jacob Andreas
标签
LMs,Test-time Training,Abstraction and Reasoning Corpus,Neural Network
摘要
语言模型在其训练分布中的任务上表现出令人印象深刻的性能,但经常难以解决需要复杂推理的新问题。本文以 ARC 为基准,研究了测试时训练 (Test-Time Training,TTT) 的有效性——在推理过程中使用从输入数据得出的损失临时更新模型参数——作为提高模型推理能力的机制。通过系统实验,确定了成功 TTT 的三个关键组成部分:(1)对类似任务进行初始微调 (2)辅助任务格式和增强 (3)逐个实例训练。TTT 显著提高了 ARC 任务的性能,与基本微调模型相比,准确率提高了约 6 倍;将 TTT 应用于 8B 参数语言模型,在 ARC 的公共验证集上实现了 53% 的准确率,将纯神经方法的最新技术提高了近 25%。通过将本文的方法与最近的程序生成方法集成,获得了 61.9% 的 SoTA 公共验证准确率,与人类的平均分数相匹配。本文的研究结果表明,显式符号搜索并不是改进神经语言模型中抽象推理的唯一途径;将额外的测试时间应用于对 few-shot 示例的继续训练也可能非常有效。
引言
大规模神经语言模型 (LMs) 擅长执行其训练数据中出现的任务,并且通常是这些任务的基本变体或组合。给定自然语言任务规范或少量示例,LMs 通常会成功推断出所需的任务并生成适当的输出。对于复杂和新颖的任务,通常很难仅通过从 LMs 中抽样来获得正确答案。
然而近年来的一个重要发现是,可以通过额外的测试时计算来增强 LM 解码,从而显著提高 LM 性能。这类方法包括思维链提示、多数投票抽样、代码执行和搜索。最近受到关注的一种扩展策略是测试时训练 (TTT),其中模型通过基于测试时输入的显式梯度步骤进行更新。这种方法与标准微调不同,因为它在极低数据范围内运行。
本文在 ARC 中评估了这些方法,这是一组极具挑战性的 few-shot 视觉推理问题。ARC 是测试 LM 泛化极限的理想基准,因为它以新颖的形式呈现新颖的任务,需要重要的搜索和推理功能。当前语言模型在 ARC 上的性能不佳,大多数成功的方法都依赖于程序合成技术。
通过仔细选择这些组件,TTT 可以显著提高 ARC 上的 LM 性能——与 1B 模型相比,准确性提高了六倍,并使用 8B 模型在 ARC 任务上为已发布的纯神经模型获得最先进的结果。事实上,本文的结果表明,当配备测试时训练时,普通 LM 可以达到或超过 ARC 上许多神经符号方法的性能。
本文的贡献如下:
- 通过新颖的测试时训练数据生成和自洽组件,确定并系统地分析了 ARC 任务的测试时训练所需的关键组件
- 在 ARC 验证集上已发布的神经方法中取得了最先进的结果:
- 使用 8B 参数模型的公共验证集的准确率为 53%
- 与程序综合方法集成时,准确率为 61.9%,与数据集上的平均人类表现相匹配
- 证明以前只能通过程序综合解决的任务,可以用配备 TTT 框架的全神经方法来解决
这些结果挑战了符号分量对于解决此类复杂任务是绝对必要的假设。
简介
ARC挑战
ARC 旨在通过语言模型解决视觉难题的能力来评估语言模型的抽象推理能力。每个任务由二维网格的输入输出对(最大 30 × 30 个)组成,其中包含由多达 10 种不同颜色组成的形状或图案,如图 1(b) 所示。每对的输出是通过应用直观且共享的转换规则或函数 y = f (x) 获得的。
ARC 中的每个任务都由训练和测试拆分组成,其中包含:
- 训练示例表示为 ( x k t r a i n , y k t r a i n ) k = 1 K (x^{train}_{k},y^{train}_k)^K_{k=1} (xktrain,yktrain)k=1K(通常 K 范围为 2 到 7)。
- 测试示例表示为 ( x m t e s t , y m t e s t ) m = 1 M (x^{test}_{m},y^{test}_m)^M_{m=1} (xmtest,ymtest)m=1M(通常 M 范围为 1 到 3)。
给定这组训练示例,目标是通过推理底层转换来预测测试输入 x t e s t x_{test} xtest 的测试输出 y t e s t y_{test} ytest。
将任务表示为 d = ( x t r a i n , y t r a i n , x t e s t , y t e s t ) d = (x_{train}, y_{train}, x_{test}, y_{test}) d=(xtrain,ytrain,xtest,ytest),其中 d ∈ D A R C d \in D_{ARC} d∈DARC,即此类 ARC 任务的集合。ARC 数据集的原始训练集和验证集,分别为 D A R C t r a i n D^{train}_{ARC} DARCtrain 和 D A R C v a l D^{val}_{ARC} DARCval,每个数据集由 400 个任务组成。 成功标准需要为所有测试输出生成完全匹配项。
大多数 ARC 方法可以分为两大类:程序综合和完全神经。程序综合方法尝试先找到变换函数 f f f,然后再将其应用于测试示例。另一方面,完全神经方法尝试直接预测输出 y t e s t y_{test} ytest,仅隐含地推理潜在的转换。在这项工作中,本文使用一种完全神经的方法,使用 LM 来预测测试输出。
上下文学习
在一定规模上,许多 LM 表现出适应新任务的能力,而无需通过简单地调节提供的输入示例或指令来更新其参数。给定一系列输入-输出对 ( x 1 , y 1 ) , . . . , ( x n , y n ) (x_1, y_1), ..., (x_n,y_n) (x1,y1),...,(xn,yn) 和一个新的输入 x n + 1 x_{n+1} xn+1,LM 可生成输出 y ^ n + 1 \widehat y_{n+1} y
n+1:
y ^ n + 1 = L M ( ⋅ ∣ x 1 , y 1 , . . . x n , y n , x n + 1 ) (1) \widehat y_{n+1} = LM(\cdot | x_1,y_1,...x_n,y_n,x_{n+1})\tag{1} y
n+1=LM(⋅∣x1,y1,...xn,yn,xn+1)(1)
上下文学习与任何标准的机器学习算法都不同,并且它不能开箱即用地用于新任务。
测试时训练 TTT
一般的 TTT 过程如下: