言简意赅学习Learning to (Learn at Test Time): RNNs with Expressive Hidden States

Main Content

所有的序列建模层都可以表示为一个根据更新规则进行转换的隐藏状态。

main idea:隐藏状态设为一个具有权重W的模型f,并将更新规则设为自监督损失ℓ的梯度步。

测试时训练(Test-Time Training, TTT):在测试序列上更新隐藏状态等同于在测试时训练模型f。

隐藏状态是一个自监督学习更新的模型。

相当于把前向传播的参数用一个model来更新

所有的序列建模层(sequence modeling layer)都可以视为该图中三个组件的不同实例:初始状态、更新规则和输出规则。

自注意力机制的隐藏状态随上下文增长,因此每个token的成本也随之增加。而naive RNN和TTT层都将增长的上下文压缩到固定大小的隐藏状态中,因此每个token的成本保持不变。

Details

所有序列建模层都可以从存储历史上下文到隐藏状态的角度进行查看:

RNN(LSTM、RWKV和Mamba)将上下文压缩到固定大小的状态中有两个后果。一方面,将输入token x_t映射到输出token z_t是高效的,因为更新规则和输出规则对每个token的处理时间是恒定的。另一方面,RNN层在长上下文中的性能受到其隐藏状态表达能力的限制。

Attention隐藏状态,通常称为Key-Value (KV) 缓存,是一个随着时间线性增长的列表。其更新规则简单地将当前KV元组附加到这个列表中,而输出规则则扫描直到时间t的所有元组以形成注意力矩阵。隐藏状态显式地存储了所有历史上下文而没有压缩,使得自注意力机制在长上下文中比RNN层更具表现力。然而,扫描这个线性增长的隐藏状态也使得每个token的处理时间线性增长。

为了在长上下文中保持高效且具表现力,我们需要一个compression heuristic将成千上万甚至数百万个tokens压缩到一个能够有效捕捉其底层结构和关系的隐藏状态中

2.1 TTT作为hidden state的更新

自监督训练的模型能够捕捉训练数据背后的底层结构和关系——compression heuristic

大型语言模型(LLMs)就是很好的例子。通过下一步token预测的自监督任务进行训练,其权重可以看作是互联网现有知识的压缩形式。通过查询LLMs,我们可以从其权重中提取知识。更重要的是,LLMs通常表现出对现有知识之间语义联系的深刻理解,以表达新的推理片段。

我们使用自监督学习将历史上下文x1,…,xt压缩到隐藏状态s_t中,将上下文看作是一个无标签的数据集,将状态看作是一个模型。具体来说,隐藏状态s_t现在等同于权重Wt,模型f可以是线性模型、小型神经网络或其他任何形式。输出规则简单为:

输出token只是由更新后的权重Wt生成的对xt的预测。自监督损失ℓ的梯度下降更新:

W 记住那些产生大梯度的输入,一种loss(ℓ)的选择是重建xt本身。为了使学习问题变得非平凡,我们首先将x_t处理成一个损坏的输入x_t,然后优化:

类似于去噪自动编码器,f需要体现x_t各维度之间的相关性,以便从部分信息x_t中重建它。

2.2 使用TTT层训练网络

TTT层与RNN层和自注意力机制有相同的接口,因此可以替换到任何更大的网络架构中,通常包含许多这种序列建模层。使用TTT层训练网络的方式也与训练任何其他语言模型(如Transformer)相同。相同的数据、配方和目标(如下一步token预测)可以用于优化网络的其余参数。

我们将训练包含TTT层的更大网络称为外循环,而在每个TTT层内训练W则称为内循环。这两个嵌套学习问题的一个重要区别在于,内循环梯度 ∇ℓ是相对于W(即f的参数)计算的,而外循环梯度是相对于网络其余部分的参数(记作 θrest)计算的。

到目前为止,TTT层还没有外循环参数,这与其他RNN层和自注意力机制不同。在2.3小节中,我们为TTT层添加外循环参数以改进其自监督任务。

2.3 TTT自监督任务

TTT最重要的部分,决定了W将从测试序列中学习到的特征。

TTT的最终目标是使

在语言建模中表现良好。

采取了一种更端到端的方法——直接优化自监督任务以实现下一步token预测的最终目标。

我们在外循环中学习自监督任务。从loss函数(ℓ)开始添加一些外循环参数以使该任务可学习。

在2.1 TTT更新hidden state中,我们没有指定将x_t转换为x_t的损坏过程。一种设计是使其成为低秩投影xt=θKxt,其中θK是一个可学习的矩阵。根据多视图重建的术语,θKxt被称为训练视图。

可能并非x_t中的所有信息都值得记住,因此重建标签可以是另一个低秩投影θVxt而不是x_t。这里

θVxt被称为标签视图,其中θV也是可学习的。新的自监督损失为:

⚠️W和θ的区别:

在内循环中,只有W被优化,因此写成ℓ的参数;而θ作为该损失函数的“超参数”。

在外循环中,θK、θV、θQ与θrest一起被优化(类似于自注意力的Key和Value参数),而W仅仅是一个隐藏状态,不是一个参数。

由于训练视图θKxt的维度小于x_t,我们不能再使用方程1中的输出规则。最简单的解决方案是创建一个测试视图θQxt,并将我们的输出规则改为:

好处:训练和标签视图指定了x_t中的信息,这些信息被压缩到W_t并通过向前传播。测试视图指定了可能不同的信息,这些信息被映射到当前的输出token z_t 并通过网络层向前传播,从而为自监督任务增加了更多的灵活性。

2.4 使用mini-batch TTT进行并行化

然而,其更新规则不能并行化:

其中Gt是下降方向。注意到一旦我们计算了Gt对于t=1,…,T,我们就可以通过GD的一般更新规则的后半部分获取所有的Wt。

mini-batch梯度下降:

其中节点是变量,边是计算。蓝色节点是输入变量,黄色是输出。由于 G1,。, Gb 没有连接,它们彼此没有顺序依赖,因此可以并行计算。

2.5 Dual Form

  • 原始 TTT 层更新规则和输出规则包含大量 matmul,导致硬件利用率低下。权重矩阵维度较大,增加了内存占用和 I/O 成本。
  • Dual form 是种优化 TTT 层硬件效率的方法,通过减少矩阵乘法 (matmul) 数量来提高性能。通过避免显式计算中间变量(例如梯度),使用矩阵乘法计算最后一个时间步的权重和输出,利用矩阵乘法计算中间变量(例如掩码矩阵)。减少 matmul 的数量,提高硬件利用率。降低内存占用和 I/O 成本。

论文链接:https://arxiv.org/abs/2407.04620

代码链接:GitHub - test-time-training/ttt-lm-pytorch: Official PyTorch implementation of Learning to (Learn at Test Time): RNNs with Expressive Hidden States

欢迎讨论!看的人多的话下次更TTT的代码详解

  • 12
    点赞
  • 15
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值