【论文笔记】线性注意力:Learning to (Learn at Test Time): RNNs with Expressive Hidden States

参考文献:Learning to (Learn at Test Time): RNNs with Expressive Hidden States

动机(Motivation)

自注意力在长上下文中表现良好,但具有二次复杂度(无论是transformer还是mamba都涉及矩阵的乘积运算)。现有的 RNN 层具有线性复杂度,但它们在长上下文中的性能受到隐藏状态表达能力的限制。因此,希望提出一种序列建模层,具有线性复杂性和富有表现力的隐藏状态。

创新点

关键思想是将隐藏状态本身作为 机器学习模型,将更新规则作为自监督学习的步骤。由于隐藏状态甚至可以通过测试序 列上的训练来更新,因此我们的层称为Test-Time Training(TTT)层。

方法论

所有序列建模层都可以从将历史上下文存储到隐藏状态的角度来看待,如图 4 所示。

RNN:压缩启发式(compression heuristic),简单来说就是,在隐藏状态层将历史上下文进行压缩存储,性能收到隐藏状态表达能力的限制,因为只会存储固定大小的内容

self-attention:隐藏状态随着t增长,将计算所得的Key-value缓存下来,从而实现显式的存储历史上下文而不进行压缩。时间消耗会随着隐藏状态的增长而增长。

期望实现的目标:更好的压缩启发式,将数千甚至数百万个令牌压缩到隐藏状态,以有效捕获其底层结构和关系。

将TTT用于更新隐藏状态

启发:参数学习的过程可以看作是将大量的数据集压缩成模型权重的过程。而自监督学习训练的模型可以捕获训练数据背后的底层结构和关系。

关键思想:将历史上下文x_1,x_2,…,x_t压缩到隐藏状态$s_t$,这一过程通过自监督学习进行,即将历史上下文x视为没有label的数据集,而将隐藏状态视为模型,也就是将隐藏状态视作自监督模型产生的模型f的权重$W_t$。f可以是任意模型,包括线性模型或者神经网络。

那么输出规则$z_t=f(x_t;W_t)$

更新规则:$Wt=W_{t-1}-\eta\nabla\ell(W_{t-1};x_t)$,其中$\eta$为学习率,$\ell$为自监督损失

压缩启发式通常需要选择记住或者遗忘哪些输入,从更新规则不难看出,TTT能够记住一些产生较大梯度的输入,直观地讲,这些输入使 W 学到了很多东西。

对于$\ell$的选择,一种选择是进行重构,$\ell(W;x_t)=|f(\tilde{x}_t;W)-x_t|^2$

与其他 RNN 层和自注意力机制一样,算法映射输入序列到输出序列可以使用上面的隐藏状态、更新规则和输出规则编程到序列建模层的前向传递中。由于在测试过程中,对于每个输入序列,我们的新层仍然训练不同的权重序列,因此被称为TTT(Test-Time Training)

使用TTT层训练网络

TTT 层与 RNN 层和自注意力机制具有相同的接口,因此可以在任何更大的网络架构 中进行替换。

代码描述

我们将训练较大的网络称为外循环,将每个 TTT 层中的进行的$W$更新的训练称为内循环。区别是在内循环中,梯度$\nabla\ell$可以更新W,也就是模型f的参数,而在外循环中则是更新网络其余部分的参数,$\theta_{\mathrm{rest}}$,在接下来的内容中,外循环的参数使用不用下标的$\theta$表示。

TTT的自监督学习任务(最重要的部分)

创新:不使用人类先验的自监督任务,而是采用更加端到端的方法---直接优化自监督任务以实现下一个token预测的最终目标。

具体来说,将自监督学习任务作为外循环的一部分。

$\ell(W;x_t)=|f(\tilde{x}_t;W)-x_t|^2$

对于从$xt$得到$\tilde{x}t$,我们可以进行低秩投影,即\tilde{x}_{t}=\theta_{K}x_{t},$\theta_{K}$为可学习(learn-able)矩阵,\theta_{K}x_{t}称为训练视图(training view)。

直观的来看,并不是xt中的所有信息都需要被记忆,那么重建标签(reconstruction)可以另一个低秩投影\theta_{V}x_{t}而可以不是$x_t$,其中$\theta_{V}$是可学习的,\theta_{V}x_{t}称为标签视图(label view)。因此,自监督损失可以写作 \ell(W;x_t)=\begin{Vmatrix}f\left(\theta_Kx_t;W\right)-\theta_Vx_t\end{Vmatrix}^2

由于训练视图\theta_{K}x_{t}$x_t$的低秩投影,故不能直接使用z_t=f(x_t;W_t)的输出规则。因此我们构造一个测试视图,\theta_Qx_t,并使用z_t=f\left(\theta_Qx_t;W_t\right)进行输出。

进一步解释,在内循环部分,只有$w$被更新,因此写作损失$\ell$的参数,而在外循环中,$w$只作为隐藏状态而不更新,\theta_K\theta_V$\theta$一起被更新。

小批量的TTT并行计算

由于TTT的更新规则W_t=W_{t-1}-\eta\nabla\ell(W_{t-1};x_t)依赖于$W_{t-1}$,因此并不能并行化。而更新规则中,主要的计算集中在$\nabla\ell$中,故主要对\nabla\ell进行并行化的设计。

梯度下降可以表示成

W_t=W_{t-1}-\eta G_t=W_0-\eta\sum_{s=1}^tG_s

因此只需要知道$G_t$就可以计算出所有的$W_t$,令G_t=\nabla l(W_{t-1};x_t)

为了进行并行化的计算,我们可以考虑G_t=\nabla\ell(W_0;x_t),这可以看作是进行批量梯度下降(batch gradient descent)。但这种方法的弊端是,W_t相比于W_0实际上只进行了一次梯度计算,因此有效搜索空间较小,会影响性能。

因此,提出了小批量梯度下降(mini-batch gradient descent)。令G_{t}=\nabla\ell\left(W_{t^{\prime}};x_{t}\right),{\mathrm{where~}t^{\prime}}=t-\mathsf{mod}(t,b),其中b是TTT的批量大小。简单来说,就是将数据划分成n个批量大小为b的子集,在子集中进行并行化的批量梯度下降计算,以此来减少浮点运算数(FLOP) .

对于批量大小b的选择,需要进行速度和质量之间的权衡,原文进行了实验,得到了如图的结果,故采取了b=16。

对偶形式

现有的加速器(accelerators)针对矩阵乘法进行了专门的研究,称为matmuls。如NVIDIA A100 GPU中包含称为 TensorCore 的高度优化单元,该单元只能执行单个操作 - 将两个大小分别为 16 × 16 的矩阵相乘。如果没有足够的这些 matmul,TensorCore 就会闲置。

然而在TTT层中,即便使用了mini-batch仍只有少量matmul。考虑最简单的情况,即对于第一个mini-batch,\theta_K=\theta_V=\theta_Q=I。假如我们考虑线性的模型$F$,那么在时间t时的损失为\ell\left(W_0;x_t\right)=|f\left(x_t;W_0\right)-x_t|^2=|W_0x_t-x_t|^2,正如我们上面所讨论的,G_t=\nabla\ell\left(W_0;x_t\right)=2(W_0x_t-x_t)x_t^T,其中t=1,...,b。我们无法通过一个matmul进行计算,相反我们需要b个进行计算。与此同时,对于x_{t}\in\mathbb{R}^{d},$Gt$的大小为$d \times d$,这意味着对于较大的d,需要更大的内存占用和I/O占用。

通过对流程的观察我们可以发现,我们的目标是$W_b$,并不需要对$G_i$进行具体化的表示,故我们可以进行简化,

W_b=W_0-\eta\sum_{t=1}^bG_t=W_0-2\eta\sum_{t=1}^b(W_0x_t-x_t)x^T=W_0-2\eta(W_0X-X)X^T

其中$X=[x_{1},\ldots,x_{b}]$

那么输出zt就可以简化为

z_t=f(x_t;W_t)=W_tx_t=\left(W_0-\eta\sum_{s=1}^tG_t\right)x_t=W_0x_t-2\eta\sum_{s=1}^t(W_0x_s-x_s)x_s^Tx_s

\delta_{t}=\sum_{s=1}^{t}(W_{0}x_{s}-x_{s})x_{s}^{T}x_{s}\Delta=[\delta_{1},\ldots,\delta_{b}],那么\Delta=\text{mask}\left(X^TX\right)\left(W_0X-X\right),其中 mask 是带有零的下三角掩模(类似于注意掩模,但用零而不是无穷大),W_0X-X可以重复使用,因此$\Delta$可以用 matmuls 方便地计算,故Z=W_{0}X-2\eta\Delta

优势:不使用对偶形式时,时间复杂度为O(b\times d^{2}),使用对偶形式后,计算W_b的复杂度为O(b\times d^{2}),计算Z为O(b^{2} \times d),是利用理论复杂性来换取硬件的利用率,从而带来速度的提升(因为d通常为几百,而b为16)

理论等价

等价一

考虑模型f为线性模型的TTT层,以学习率$\eta$=1/2的批量梯度下降,$W_0=0$,那么按照输出规则z_t=f\left(\theta_Qx_t;W_t\right)得到的结果与线性注意力的结果等价。

证明:\ell(W;x_t)=|W\theta_Kx_t-\theta_Vx_t|^2,对$w$求导可得\nabla_W\ell(W;x_t)=2\left(W\theta_Kx_t-\theta_Vx_t\right)(\theta_Kx_t)^T,代入w=w_0=0,可得\nabla\ell\left(W_0;x_t\right)=-2(\theta_Vx_t)(\theta_Kx_t)^T。代入$W_t$的计算式,可得W_{t}=W_{t-1}-\eta\nabla\ell\left(W_{0};x_{t}\right)=W_{0}-\eta\sum_{s=1}^{t}\nabla\ell\left(W_{0};x_{s}\right)=\sum_{s=1}^{t}(\theta_{V}x_{s})(\theta_{K}x_{s})^{T}

那么$z_ t$可得为z_t=f\left(\theta_Qx_t;W_t\right)=\sum_{s=1}^t(\theta_Vx_s)(\theta_Kx_s)^T(\theta_Qx_t),等价于线性注意力

等价二

对于非参数的模型f,不存在$w_t$,故我们使用符号f(x;x_{1},\ldots,x_{t})。使用Nadaraya-Watson estimator考虑TTT层,定义f(x;x_1,\ldots,x_t)=\frac{1}{\sum_{s=1}^t\kappa(x,x_s)}\sum_{s=1}^t\kappa(x,x_s) y_s,其中y_{s}=\theta_{V}x_{s}\kappa\left(x,x';\theta_K,\theta_Q\right)\propto e^{(\theta_Kx)^T\theta_Qx'}为具有\theta_K,\theta_Q超参数的核函数,那么按照输出规则获得的z_t与自注意力等价

证明:将y_{s}\kappa(x,x_s)代入式子,即得到自注意力的定义。

推导过程:自注意力z_t=\sum_{s=1}^t\text{softmax}\left(\frac{(\theta_Kx_s)^T\theta_Qx_t}{\sqrt{d_k}}\right)y_s

将核函数代入f(x;x_1,\ldots,x_t)=\frac1{\sum_{s=1}^t\exp\left((\theta_Kx)^T\theta_Qx_s\right)}\sum_{s=1}^t\exp\left((\theta_Kx)^T\theta_Qx_s\right)y_s

由于\mathrm{softmax}(x)=\frac{\exp(x)}{\sum\exp(x)},观察可得,这与Nadaraya-Watson估计器中通过核函数生成的权重形式是等价的。因此,当核函数 κ(x,xs​) 取为上述形式时,Nadaraya-Watson估计器的输出与自注意力机制的输出形式一致。

实验结果

短文本:Pile数据集

在2k 上下文的条件下下,TTT-Linear (M)、Mamba 和Transformer 具有相当的性能,因为线 条大部分重叠。

在 8k 上下文背景下,TTT-Linear (M) 和 TTT-MLP (M) 的表现均明显优于 Mamba,与 2k 下的观察结 果相反。即使具有 Transformer 主干的 TTT-MLP (T) 性能也比 Mamba 略好,约为 1.3B。我们在 本文中观察到的一个强有力的现象是,随着上下文长度变长,TTT 层相对于 Mamba 的优势扩大。同时,Transformer 在每个模型大小上仍然具有良好(如果不是最好)的困惑度,但由于 FLOP 成本, 其产品线不具有竞争力。

长上下文:Books数据集

使用称为 Books3 的 Pile 的子集。

在 Books 的 2k 上下文中,Pile Пk 的所有观察结果仍然成立,除了 Mamba 现在的表现略好于 TTT-Linear(而 它们的线在 Pile Пk 中大致重叠)。

在 32k 上下文中,TTT-Linear (M) 和 TTT-MLP (M) 的性能均优于 Mamba,与 Pile Хk 的观察 结果类似。即使具有 Transformer 主干的 TTT-MLP (T) 在 32k 上下文中的表现也比 Mamba 稍好。

在1.3B 尺度上,TTT-MLP (T) 仅比TTT-MLP (M) 稍差。正如所讨论的,由于缺 乏清晰的线性拟合,很难推导出经验缩放定律。然而,TTT-MLP (T) 的强劲趋势表 明 Transformer 主干可能更适合超出我们评估的更大模型和更长上下文。

  • 19
    点赞
  • 23
    收藏
    觉得还不错? 一键收藏
  • 2
    评论
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值