知乎:刀刀宁
链接:https://zhuanlan.zhihu.com/p/718156896
线性注意力机制的文章有很多了,在本篇笔记中,我们简单地对各种方法进行一下图解比较,串一下当前的线性注意力机制,涉及的公式极少,主要梳理逻辑脉络。本文会从 state space model 中间状态模型这条主线,来梳理 RNN、LSTM,再到 Retentive、GLA 等 Linear Attention 的改进版,最后再到 Mamba、Mamba-2、RWKV 等方法。
线性注意力机制的好处很多,可以用“多快好省”来形容:理论复杂度低、速度快、结构简单、上下文长度线性依赖、KVCache 不需要额外存储,且优化容易。但相比 full attention,线性注意力机制的表达能力确实差一截,且无法完全丢弃历史信息,类似于 RNN 的遗忘和依赖关系,因此产生了各种改进方法。
同时,线性注意力也具备很多并行和 IO 感知的优化,否则复杂度线性化后,并行和运算速度若不如 full attention,就显得鸡肋。因此,如何结合硬件(主要是 CUDA GPU 的特点)来进行注意力机制的系统级优化是不可忽略的问题。
Part 1: Linear Attention 与非必要 softmax
Linear Attention Transformers (Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention) 的论文发表于 2020 年。
https://proceedings.mlr.press/v119/katharopoulos20a/katharopoulos20a.pdf
为了帮助理解,先引用苏神 2020 的文章:《Attention 必须有个 Softmax 吗?》以及《超无聊的:Gated Linear Attention Transformers with Hardware-Efficient Training》,来看一下去掉 softmax 函数后的 attention 机制。这里省略公式和证明,感兴趣的读者可移步前文。
https://spaces.ac.cn/archives/7546 https://zhuanlan.zhihu.com/p/672824235
下图左图是原来的 attention 机制,矩阵乘法的顺序和计算复杂度:设序列长度为 ,当前复杂度为 级,这是我们熟悉的情形。而右图则去掉了 softmax,用近似函数 sim
替代,并改变了 QKV 的计算顺序(本文中的典型线性注意力机制)。这时,神奇的事情发生了:中间结果从 的矩阵变成了 ,复杂度变成了 线性(当然,若 是 4096 级别, 也很大,此时还需考虑减小 等方法)。但整个运算过程与 的长度呈线性相关性。
这只是 softmax 的原因吗?最初认为是,但深入研究后发现&#x