(2024,LSTM,Transformer,指数门控,归一化器状态,多头内存混合)xLSTM:扩展的 LSTM

xLSTM: Extended Long Short-Term Memory

公和众和号:EDPJ(进 Q 交流群:922230617 或加 VX:CV_EDPJ 进 V 交流群)

目录

0. 摘要

1. 简介

2. 扩展的 LSTM

2.1 LSTM 回顾

2.2 sLSTM

2.3 mLSTM

2.4 xLSTM 架构

2.5 内存和速度考虑

4. 实验

5. 限制

6. 结论


0. 摘要

在 1990s,恒定误差旋转门(constant error carousel)和门控(gating)被引入作为长短期记忆(Long Short-Term Memory,LSTM)的核心思想。从那时起,LSTM 经受住了时间的考验,并为许多深度学习成功故事做出了贡献,特别是构成了第一个大型语言模型(LLMs)。然而,随着 Transformer 技术的出现,以可并行化的自注意力为核心,标志着一个新时代的开端,在规模上超越了 LSTM。我们现在提出一个简单的问题:当将 LSTM 的规模扩大到数十亿参数,并利用现代 LLM 的最新技术,同时减轻已知的 LSTM 限制时,我们能在语言建模方面取得多远?首先,我们引入具有适当归一化和稳定化技术的指数门控(exponential gating)。其次,我们修改 LSTM 的内存结构,获得:(i)带有标量内存、标量更新和新的内存混合的 sLSTM,(ii)完全可并行化的 mLSTM,具有矩阵内存和协方差更新规则。将这些 LSTM 扩展集成到残差块骨干中,得到了 xLSTM 块,然后将它们残差叠加到 xLSTM 架构中。指数门控和修改后的内存结构提升了 xLSTM 的能力,使其在性能和规模方面与最先进的 Transformers 和状态空间模型(SSM)相比表现出色。 

1. 简介

LSTM 的思想(Hochreiter, 1991; Hochreiter & Schmidhuber, 1997b,a),即恒定误差旋转门和门控,是为了克服 RNN(Hochreiter, 1991; Hochreiter et al., 2000)的梯度消失问题而引入的:

恒定误差旋转门是单元状态 c_(t−1)(绿色)通过单元输入 zt 的加法更新,并由 sigmoid 门(蓝色)调节。输入门 it 和遗忘门 ft 控制此更新,而输出门 ot 控制内存单元的输出,即隐藏状态 ht 。单元状态由 ψ 归一化或压缩,然后输出门给出隐藏状态。 

尽管 LSTM 取得了巨大的成功,但仍存在三个主要限制:

  • 无法修订存储决策。我们通过最近邻搜索问题来说明这一限制(也见附录 B):给定一个参考向量,必须按顺序扫描序列,以找到最相似的向量,以便在序列末尾提供。图 2 的左侧面板显示了该任务的均方误差。当找到一个更相似的向量时,LSTM 在修订存储值时遇到困难,而我们的新 xLSTM 通过指数门控修复了这个限制。
  • 有限的存储容量,即信息必须压缩到标量单元状态中。我们通过稀有标记预测(Rare Token Prediction)来说明这一限制。在图 2 的右侧面板中,给出了对 Wikite-103(Merity et al., 2017)上的标记预测的困惑度,针对不同标记频率的块(buckets)。由于其有限的存储容量,LSTM 在稀有标记上表现较差。我们的新 xLSTM 通过矩阵存储解决了这个问题。
  • 由于记忆混合导致缺乏并行性,即从一个时间步到下一个时间步的隐藏状态之间的隐藏-隐藏连接,强制进行了顺序处理。

这些 LSTM 的限制为 Transformers(Vaswani et al., 2017)在语言建模中的出现铺平了道路。当克服这些限制并将 LSTM 扩展到当前大型语言模型的规模时,我们能够在语言建模中实现什么样的性能?

2. 扩展的 LSTM

2.1 LSTM 回顾

原始的 LSTM 思想(Hochreiter, 1991; Hochreiter & Schmidhuber, 1997b,a)引入了标量内存单元作为一个中心处理和存储单元,通过恒定误差旋转门(单元状态更新)避免了梯度消失(Hochreiter, 1991; Hochreiter et al., 2000)。内存单元包含三个门:输入门、输出门和遗忘门。遗忘门由 Gers 等人引入(2000年)。在时间步 t,LSTM 内存单元的更新规则为:

其中,

  • 权重向量 w_z, w_i, w_f, 和 w_o 分别对应于输入 x_t 与单元输入、输入门、遗忘门以及输出门之间的输入权重向量。
  • 权重 r_z, r_i, r_f, 和 r_o 对应于隐藏状态 h_{t-1} 与单元输入、输入门、遗忘门以及输出门之间的递归权重。
  • b_z, b_i, b_f, 和 b_o 是相应的偏置项。
  • φ 和 Ψ 是单元输入和隐藏状态激活函数(通常为双曲正切)。Ψ 用于归一化或压缩单元状态,否则将无界。
  • 所有门的激活函数都是 sigmoid 函数,即 σ(x) = 1/(1 + exp(-x))。

在后续的公式中,多个内存单元被合并成一个向量,这允许使用递归权重矩阵来混合内存单元的单元输出(Greff et al., 2015),更多细节请参见附录 A.1。消融研究表明,内存单元的所有组件都至关重要(Greff et al., 2015)。

2.2 sLSTM

为了赋予 LSTM 修订存储决策的能力,我们引入了指数门控(红色)以及归一化和稳定化。特别地,输入门和遗忘门可以具有指数激活函数。对于归一化,我们引入一个归一化器(normalizer)状态,它将输入门与所有未来遗忘门的乘积相加。 

sLSTM的前向传播过程是:

我们将原始的 LSTM 门控技术,即输入和/或隐藏依赖的门控以及偏置项,广播到新的架构中。指数激活函数可能导致产生大值而引起溢出。因此,我们使用额外的状态 m_t(Milakov & Gimelshein, 2018)来稳定门控:  

我们在附录 A.2 中展示,将 ft 替换为 f'_t,以及将 it 替换为 i'_t 在前向传播中既不会改变整个网络的输出,也不会改变损失对参数的导数。

新的内存混合。sLSTM 可以像原始的 LSTM 一样具有多个内存单元(见附录 A.2)。多个内存单元通过从隐藏状态向量 h 到内存单元输入 z 和门 i、f、o 的递归连接 rz、ri、rf、ro 实现内存混合。内存混合的新方面是指数门的影响。新的 sLSTM 可以在每个头部(head)内进行内存混合,但不能跨头部进行混合。引入头部对 sLSTM 的指数门以及内存混合建立了一种新的内存混合方式。

附录:基于 Greff 等人(2015)的标准 LSTM 内存单元更新规则,在时间步 t 将标量单元状态公式扩展为单元状态向量,类似地,sLSTM 也可以向量化为多个单元:

2.3 mLSTM

为了增强 LSTM 的存储容量,我们将 LSTM 内存单元从标量 c ∈ R 增加到矩阵 C ∈ R^(d×d)。因此,检索是通过矩阵乘法执行的。在时间 t,我们想要存储一对向量,即键 k_t ∈ R^d 和值 v_t ∈ R^d(我们使用 Transformer 术语)。稍后在时间 t + τ,值 v_t 应该由查询向量 q_(t+τ) ∈ R^d 检索。这是双向联想记忆(Bidirectional Associative Memories,BAMs)(Kohonen, 1972; Anderson, 1972; Nakano, 1972; Anderson et al., 1977)的设置。存储键-值对的协方差更新规则(Sejnowski, 1977; Dayan & Willshaw, 1991)是

我们假设在将输入投影到键和值之前进行层归一化,因此它们的平均值为零。协方差更新规则是最优的(Dayan & Willshaw, 1991),可实现检索的二进制向量的最大可分性,这等效于最大的信噪比。当将检索限制为成对交互并接受二次复杂度时,更高的可分性是可能的(Krotov & Hopfield, 2016, 2017; Ramsauer et al., 2021)。协方差更新规则等效于快速权重编程器(Schmidhuber, 1992; Schlag et al., 2021),后者已经配备了一个乘以 C_(t−1) 的恒定衰减率和一个乘以 v_t·k^T_t 的恒定学习率(Ba et al., 2016a)。在这个精神上,我们将协方差更新规则集成到 LSTM 框架中,其中遗忘门对应于衰减率,输入门对应于学习率,而输出门缩放检索到的向量。

对于这个矩阵内存,归一化器状态是键向量的加权和,其中每个键向量都由输入门和所有未来遗忘门加权。同样,归一化器状态记录门的强度。由于查询和归一化器状态之间的点积可能接近零,我们使用该点积的绝对值,并将其下限设为一个阈值(通常为 1.0),就像以前一样(Sun et al., 2023)。mLSTM 的前向传播过程是: 

mLSTM 可以像原始的 LSTM 一样具有多个内存单元。对于 mLSTM,因为没有内存混合,多个头部和多个单元是等价的。为了稳定 mLSTM 的指数门,我们使用与 sLSTM 相同的稳定化技术,参见方程(15)。由于 mLSTM 没有内存混合,这种递归可以重新表述为并行版本。更多细节请参阅附录 A.3。 

2.4 xLSTM 架构

xLSTM 块。xLSTM 块应该在高维空间中非线性地总结过去,以更好地区分不同的历史或上下文。分离历史是正确预测下一个序列元素(如下一个标记)的前提。我们诉诸于 Cover 定理(Cover, 1965),该定理指出在高维空间中,非线性嵌入的图样(patterns)更可能被线性分离,而不是在原始空间中。我们考虑两种残差块架构:

  • 一个带有后上投影(post up-projection)的残差块(类似于Transformer),它在原始空间中非线性地总结过去,然后线性映射到高维空间,应用非线性激活函数,然后线性映射回原始空间;见图 3 的左侧面板和图 1 中的第三列。附录中的图 9 展示了更详细的版本。
  • 一个带有前上投影(pre up-projection)的残差块(类似于 SSM),它线性映射到高维空间,然后在高维空间中非线性地总结过去,最后线性映射回原始空间。

对于包含 sLSTM 的 xLSTM 块,我们主要使用后向投影块。对于包含 mLSTM 的 xLSTM 块,我们使用预向投影块,因为在高维空间中的存储容量更大。有关更多细节,请参见图 3 的左侧面板和图 1 的第三列,或附录中的图 9。

图 9:sLSTM 块的示意图 - 后上投影(post up-projection):嵌入在 pre-LayerNorm 残差结构中,输入可以选择通过窗口大小为 4 的因果卷积进行传递,其中包括用于输入门和遗忘门的 Swish 激活。然后,对于所有输入、遗忘和输出门 i、f、o,以及单元更新 z,输入通过一个具有四个对角块或 “头” (Head)的对角线线性层。这些对角块与来自上一个隐藏状态的递归门 pre-activations 相一致,对应于一个具有四个头的 sLSTM,用圆形箭头表示。得到的隐藏状态通过一个 GroupNorm 层(Wu & He, 2018) - 对于每个头部的 LayerNorm。最后,输出通过一个门控 MLP 进行上下投影,使用 GeLU 激活函数和投影因子(PF) 4/3 来匹配参数。 

图10:mLSTM 块的示意图 - 前上投影(pre up-projection):嵌入在 pre-LayerNorm 残差结构中,首先对输入进行上投影,投影因子为 2,一次用于外部化输出门,一次作为 mLSTM 单元的输入。 mLSTM 单元的输入在维度方向上因果卷积(卷积核大小为4)之后,进入可学习的跳跃连接。我们通过块(Block)大小为 4 的块对角投影矩阵获得输入 q 和 k。值 v 直接馈送,跳过卷积部分。在 mLSTM 序列混合之后,通过 GroupNorm(Wu & He, 2018)进行输出归一化(对于每个头的 LayerNorm)。最后,将可学习的跳跃输入添加到结果中,并使用外部输出门对结果进行逐分量门控。然后进行下投影。 

xLSTM 架构。xLSTM 架构是通过残差堆叠构建块(Srivastava等,2015; He等,2016)构建的。我们依赖于当代大型语言模型中最常用的 pre-LayerNorm(Ba等,2016b)残差主干。请参见图 1 中的最后一列。

2.5 内存和速度考虑

与 Transformer 相反,xLSTM 网络具有线性计算和与序列长度相对应的恒定内存复杂度。由于 xLSTM 内存具有压缩性,因此非常适合工业应用和在边缘上的实现。

mLSTM 的记忆不需要参数,但通过其 d×d 矩阵存储和 d×d 更新而在计算上昂贵。我们在内存容量与计算复杂性之间进行权衡。尽管如此,计算可以在 GPU 上并行进行,因此这些计算对墙上时钟时间(wall clock time)的影响很小。

虽然 mLSTM 类似于 FlashAttention(Dao等,2022; Dao,2024)或 GLA(Yang等,2023)可并行化,但由于内存混合(隐藏-隐藏连接),sLSTM 不可并行化。然而,我们开发了一个快速的 CUDA 实现,通过 GPU 内存优化到寄存器级别,通常比 mLSTM 慢不到两倍。 

4. 实验

5. 限制

  • 与 mLSTM 相比,sLSTM 的内存混合阻止了可并行化操作,因此不允许快速的并行实现。尽管如此,我们为 sLSTM 开发了一个快速的 CUDA 核心,目前的速度大约比我们的并行 mLSTM 实现慢了 1.5 倍左右。
  • mLSTM 的 CUDA 核心尚未优化,因此当前的实现速度约为 FlashAttention 或 Mamba 中使用的扫描的 4 倍。可以通过类似于 FlashAttention 的方法获得更快的 CUDA 核心。
  • 因为必须处理 d×d 矩阵,mLSTM 的矩阵内存具有较高的计算复杂性。尽管如此,内存的更新和检索不使用参数,并且可以使用标准矩阵操作进行并行化,因此由于复杂的内存而引起的墙上时钟时间开销很小。
  • 遗忘门的初始化必须谨慎选择。
  • 由于矩阵内存与序列长度有关(原论文中为无关,我认为应该是有关),增加序列长度可能会使较长上下文大小的内存超载。尽管如此,对于长达 16k 的上下文来说,这似乎并不是一个限制,参见第 4.3 节。
  • 由于大型语言实验的昂贵计算负载,我们既没有完全优化架构,也没有优化超参数,特别是对于更大的 xLSTM 架构。我们预计,xLSTM 达到其全部潜力需要进行广泛的优化过程。

6. 结论

我们部分回答了我们的简单问题:将 LSTM 扩展到数十亿个参数时,我们能取得多远的语言建模进展?到目前为止,我们可以回答:“至少与当前的技术(如 Transformer 或 SSM)一样远”。我们通过指数门和内存混合以及新的内存结构将 LSTM 改进为 xLSTM。与 Transformer 和 SSM 等最新方法相比,xLSTM 模型在语言建模方面表现良好。扩展定律表明,更大的 xLSTM 模型将成为使用 Transformer 技术构建的当前大型语言模型的严肃竞争对手。xLSTM 有潜力对其他深度学习领域产生重大影响,如强化学习、时间序列预测或物理系统建模。 

  • 11
    点赞
  • 11
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值