【论文阅读】xLSTM: Extended Long Short-Term Memory

xLSTM: Extended Long Short-Term Memory

引用: Beck M, Pöppel K, Spanring M, et al. xLSTM: Extended Long Short-Term Memory[J]. arXiv preprint arXiv:2405.04517, 2024.

论文链接: [2405.04517] xLSTM: Extended Long Short-Term Memory (arxiv.org)

作者: Maximilian Beck, Korbinian Pöppel, Markus Spanring, Andreas Auer, Oleksandra Prudnikova, Michael Kopp, Günter Klambauer, Johannes Brandstetter, Sepp Hochreiter

机构: ELLIS Unit, LIT AI Lab, Institute for Machine Learning, JKU Linz, Austria; NXAI Lab, Linz, Austria; NXAI GmbH, Linz, Austria

摘要

在这里插入图片描述

  • 论文提出了xLSTM,一种扩展的长短期记忆网络,旨在解决传统LSTM的局限性,并在大规模参数下进行语言建模。
  • xLSTM引入了指数门控和适当的归一化与稳定技术,修改了LSTM记忆结构,包括标量记忆的sLSTM和完全可并行化的具有矩阵记忆和协方差更新规则的mLSTM。
  • 通过将这些LSTM变体集成到残差块中,构建了xLSTM架构,这些架构在性能和扩展性方面与最先进的Transformers和状态空间模型相媲美。

引言

  • LSTM自1990年代引入以来,在多个领域取得了成功,特别是在大型语言模型(LLMs)中。
  • 引入Transformer技术后,其并行化的自注意力机制使得LSTM在大规模应用中的性能受到挑战。
  • 论文提出了一个问题:当LSTM扩展到数十亿参数,并结合现代LLMs的最新技术,同时克服LSTM的已知限制时,我们能在语言建模中走多远?

xLSTM架构

1. sLSTM(Scalar LSTM)

指数门控是sLSTM中的一个创新点,它允许模型更有效地更新其记忆状态。在传统的LSTM中,门控机制通常涉及sigmoid函数,但在xLSTM中,输入门( i t i_t it)和遗忘门( f t f_t ft)可以具有指数激活函数:

c t = f t c t − 1 + i t z t c _ { t } = f _ { t } c _ { t - 1 } + i _ { t } z _ { t } ct=ftct1+itzt

n t = f t n t − 1 + i t n _ { t } = f _ { t } n _ { t - 1 } + i _ { t } nt=ftnt1+it

h t = o t h t ~ , h t ~ = o t / n t h _ { t } = o _ { t } \tilde{h _ { t }}, \quad \tilde{h _ { t }} = o _ { t } / n _ { t } ht=otht~,ht~=ot/nt

z t = φ ( z ~ t ) , z ~ t = w z T x t + r z h t − 1 + b z z _ { t } = \varphi ( \tilde { z } _ { t } ), \quad \tilde { z } _ { t } = w _ { z } ^ { T } x _ { t } + r _ { z } h _ { t - 1 } + b _ { z } zt=φ(z~t),z~t=wzTxt+rzht1+bz

i t = e x p ( i ~ t ) , i ~ t = w i T x t + r i h t − 1 + b i i _ { t } = exp ( \tilde { i } _ { t } ), \quad \tilde { i } _ { t } = w _ { i } ^ { T } x _ { t } + r _ { i } h _ { t - 1 } + b _ { i } it=exp(i~t),i~t=wiTxt+riht1+bi

f t = σ ( f ~ t ) O R e x p ( f ~ t ) , f ~ t = w f T x t + r f h t − 1 + b f f _ { t } = \sigma ( \tilde { f } _ { t } ) \quad OR \quad e x p ( \tilde { f } _ { t } ), \quad \tilde { f } _ { t } = w _ { f } ^ { T } x _ { t } + r _ { f } h _ { t - 1 } + b _ { f } ft=σ(f~t)ORexp(f~t),f~t=wfTxt+rfht1+bf

o t = e x p ( o ~ t ) , o ~ t = w o T x t + r o h t − 1 + b o o _ { t } = exp ( \tilde { o } _ { t } ), \quad \tilde { o } _ { t } = w _ { o } ^ { T } x _ { t } + r _ { o } h _ { t - 1 } + b _ { o } ot=exp(o~t),o~t=woTxt+roht1+bo

指数激活函数可能导致较大的值,从而导致溢出。因此,用一个额外的状态 m t m_t mt来稳定门:

m t = max ⁡ ( log ⁡ ( f t ) + m t − 1 , log ⁡ ( i t ) ) m _ { t } = \max ( \log ( f _ { t } ) + m _ { t - 1 } , \log ( i _ { t } ) ) mt=max(log(ft)+mt1,log(it))

i t ′ = e x p ( log ⁡ ( i t ) − m t ) = e x p ( i ~ t − m t ) i _ { t } ^ { \prime } = e x p ( \log ( i _ { t } ) - m _ { t } ) = e x p ( \tilde { i } _ { t } - m _ { t } ) it=exp(log(it)mt)=exp(i~tmt)

f t ′ = e x p ( log ⁡ ( f t ) + m t − 1 − m t ) f _ { t } ^ { \prime } = e x p ( \log ( f _ { t } ) + m _ { t - 1 } - m _ { t } ) ft=exp(log(ft)+mt1mt)

其中,$m_t​ $是稳定状态,用于防止梯度爆炸。

同时,sLSTM引入了新的记忆混合技术,允许在多个内存单元之间进行更复杂的交互。多个存储器单元使得能够分别经由从隐藏状态向量 h h h到存储器单元输入 z z z和门 i i i f f f o o o的循环连接 R z R_z Rz R i R_i Ri R f R_f Rf R o R_o Ro进行存储器混合。sLSTM可以有多个头,每个头内混合内存,但不能跨头混合。

2. mLSTM(Matrix LSTM)

mLSTM使用矩阵记忆来增强存储容量,并通过协方差更新规则来存储关键值对。

C t = f t C t − 1 + i t v t k t T C _ { t } = f _ { t } C _ { t - 1 } + i _ { t } v _ { t } k _ { t } ^ { T } Ct=ftCt1+itvtktT

n t = f t n t − 1 + i t k t n _ { t } = f _ { t } n _ { t - 1 } + i _ { t } k _ { t } nt=ftnt1+itkt

h t = o t ⊙ h ~ t , h ~ t = C t q t / max ⁡ { ∣ n t T q t ∣ , 1 } h _ { t } = o _ { t } \odot \tilde { h } _ { t } , \quad \tilde { h } _ { t } = C _ { t } q _ { t } / \max \left\{ | n _ { t } ^ { T } q _ { t } | , 1 \right\} ht=oth~t,h~t=Ctqt/max{ntTqt,1}

q t = W q x t + b q q _ { t } = W _ { q } x _ { t } + b _ { q } qt=Wqxt+bq

k t = 1 d W k x t + b k k _ { t } = \frac { 1 } { \sqrt { d } } W _ { k } x _ { t } + b _ { k } kt=d 1Wkxt+bk

v t = W v x t + b v v _ { t } = W _ { v } x _ { t } + b _ { v } vt=Wvxt+bv

i t = e x p ( i ~ t ) , i ~ t = w i T x t + b i i _ { t } = e x p ( \tilde { i } _ { t } ) , \quad \tilde { i } _ { t } = w _ { i } ^ { T } x _ { t } + b _ { i } it=exp(i~t),i~t=wiTxt+bi

f t = σ ( f ~ t ) O R e x p ( f ~ t ) , f ~ t = w f T x t + b f f _ { t } = \sigma ( \tilde { f } _ { t } ) \quad OR \quad exp(\tilde { f } _ { t }), \quad \tilde { f } _ { t } = w _ { f } ^ { T } x _ { t } + b _ { f } ft=σ(f~t)ORexp(f~t),f~t=wfTxt+bf

o t = σ ( o ~ t ) , o ~ t = w o T x t + b o o _ { t } = \sigma ( \tilde { o } _ { t } ) , \quad \tilde { o } _ { t } = w _ { o } ^ { T } x _ { t } + b _ { o } ot=σ(o~t),o~t=woTxt+bo

3. xLSTM块(xLSTM Blocks)

在这里插入图片描述
在这里插入图片描述

xLSTM块结合了sLSTM和mLSTM的特性,并通过残差连接来进一步提高性能。对于残差sLSTM块,输入首先进入sLSTM,然后是一个门控的多层感知机(MLP)。对于残差mLSTM块,输入首先通过两个MLP,然后是mLSTM,通过卷积、可学习的跳跃连接和输出门。

4. xLSTM架构(xLSTM Architecture)

xLSTM架构通过残差堆叠xLSTM块来构建,利用了预层归一化(preLayerNorm)残差骨干网络。

实验

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

  • 论文在合成任务和长距离竞技场(Long Range Arena)上测试了xLSTM,并与其他方法进行了比较。
  • 在SlimPajama数据集上进行了语言建模实验,比较了不同方法的性能。
  • 进行了扩展实验,训练了更大的模型,并在更多的训练数据上评估了它们的扩展行为。

结论

  • xLSTM在语言建模方面的表现至少与当前的Transformer或状态空间模型相当。
  • xLSTM有潜力在强化学习、时间序列预测或物理系统建模等深度学习领域产生重大影响。
  • 17
    点赞
  • 16
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

煌澄艾

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值