2023 | RetNet:全新大模型基础架构,Transformer的继任者

#Transformer#

近年来,Transformer已经成为了大型语言模型普遍采用的架构。虽然该架构在训练过程中能够实现出色的并行性和卓越性能,但在推理阶段却面临较高的成本。为了解决这一问题,文章引入了一种创新的网络结构,即RetNet。相对于传统的Transformer架构和多种变体,RetNet架构的优势是同时具备三个特点:训练可并行、推理成本低、具备良好的性能。

原文标题:Retentive Network: A Successor to Transformer for Large Language Models
作者信息:Yutao Sun,Li Dong,Shaohan Huang,Shuming Ma,Yuqing Xia,Jilong Xue,Jianyong Wang,Furu Wei(作者团队来自微软研究院和清华大学)
论文链接:https://arxiv.org/abs/2307.08621

1.介绍

在深度学习领域,循环神经网络(RNNs)以序列方式逐一处理输入数据,某个时间步骤上的输入处理取决于前一个时间步骤的隐藏状态,因此无法进行并行计算,从而降低了训练速度。而Transformer则采用了高度可并行化的自注意力机制,使得每个时间步的输出能够以Q、K、V矩阵的方式进行并行处理。不过,这种自注意力机制有助于Transformer在GPU上实现出色的并行性,但也导致了推理过程中的高成本。

研究者们一直在努力开发新一代架构,其目标是在保持训练并行性和Transformer性能的同时,实现高效的推理。要同时实现上述目标(即下图的“不可能三角“)是一项极具挑战性的任务。
在这里插入图片描述
该文提出了一个新的大语言模型自回归基础架构 Retentive Networks (RetNet),解决了“不可能三角”挑战。RetNet 在正中间,表示同时具备三个优点:推理成本低、训练可并行、良好的性能。而其他的架构 Linear Transformer、Recurrent Network 和 Transformer 都只能同时具备其中两个优点。

2.模型框架

RetNet 架构和 Transformer 类似,也是由L个相同的块堆叠而成。每一个 RetNet 块包含两个部分:一个multi-scale retention (MSR) 模块和一个feed-forward network (FFN) 模块。整体架构如下面的公式所示:
Y l = M S R ( L N ( X l ) ) + X l X l + 1 = F F N ( L N ( Y l ) ) + Y l \begin{gather*} Y^l=MSR(LN(X^l))+X^l \\ X^{l+1}=FFN(LN(Y^l))+Y^l \end{gather*} Yl=MSR(LN(Xl))+XlXl+1=FFN(LN(Yl))+Yl
输入序列 { x i } i = 1 ∣ x ∣ \{x_i\}_{i=1}^{|x|} {xi}i=1x通过一个词嵌入层转换为向量。然后使用打包好的嵌入 X 0 = [ x 1 , . . . , x ∣ x ∣ ] ∈ R ∣ x ∣ × d m o d e l X^0=[x_1,...,x_{|x|}]\in \mathbb{R}^{|x|\times d_{model}} X0=[x1,...,xx]Rx×dmodel作为输入,计算模型的输出 X L X_L XL

公式中LN(.)是LayerNorm。FFN部分,采用下面的公式计算:
F F N ( X ) = g e l u ( X W 1 ) W 2 FFN(X)=gelu(XW_1)W_2 FFN(X)=gelu(XW1)W2,其中 W 1 W_1 W1 W 2 W_2 W2是参数矩阵。

在后面主要介绍MSR模块。

2.1 Retention

首先对词嵌入向量X序列的第n个时间步的向量乘以权重 ω \omega ω,得到投影 v n v_n vn v ( n ) = X n ⋅ ω v v(n)=X_n \centerdot \omega_v v(n)=Xnωv

然后类似Transformer架构,计算Q和K的投影: Q = X W Q , K = X W K Q=XW_Q,K=XW_K Q=XWQK=XWK

接着假设一个序列建模的问题,通过状态 s n s_n sn v n v_n vn映射为 o n o_n on向量,以递归的方式定义映射:

s n = A s n − 1 + K n T v n , A ∈ R d × d , K n ∈ R 1 × d o n = Q n s n = ∑ m = 1 n Q n A n − m K m T v m , Q n ∈ R 1 × d \begin{gather*} s_n=As_{n-1}+K_n^Tv_n,&A\in\mathbb{R}^{d\times d},K_n\in\mathbb{R}^{1\times d}\\ o_n=Q_ns_n=\sum_{m=1}^nQ_nA^{n-m}K_m^Tv_m,&Q_n\in\mathbb{R}^{1\times d} \end{gather*} sn=Asn1+KnTvn,on=Qnsn=m=1nQnAnmKmTvmARd×d,KnR1×dQnR1×d

其中,A是一个矩阵, K n K_n Kn表示时间步n对应的K投影,类似地, Q n Q_n Qn表示时间步n对应的Q投影。

接下来,利用对角化简化方程: A = Λ ( γ e i θ ) Λ − 1 A=\Lambda (\gamma e^{i\theta})\Lambda^{-1} A=Λ(γeiθ)Λ1,得到新的 o n o_n on表达式:
o n = ∑ m = 1 n Q n ( γ e i θ ) n − m K m T v m = ∑ m = 1 n ( Q n ( γ e i θ ) n ) ( K m ( γ e i θ ) − m ) T v m \begin{gather*} o_n=\sum_{m=1}^nQ_n(\gamma e^{i\theta})^{n-m}K_m^Tv_m\\ =\sum_{m=1}^n(Q_n(\gamma e^{i\theta})^n)(K_m(\gamma e^{i\theta})^{-m})^Tv_m \end{gather*} on=m=1nQn(γeiθ)nmKmTvm=m=1n(Qn(γeiθ)n)(Km(γeiθ)m)Tvm

其中, Q n ( γ e i θ ) n Q_n(\gamma e^{i\theta})^n Qn(γeiθ)n K m ( γ e i θ ) − m K_m(\gamma e^{i\theta})^{-m} Km(γeiθ)m是xPOS,一种为transformer设计的位置编码。

再将 γ \gamma γ定义为一个标量,则可以将上述公式进一步简化为:
o n = ∑ m = 1 n γ n − m ( Q n e i n θ ) ( K m e i m θ ) † v m o_n=\sum_{m=1}^n \gamma^{n-m}(Q_ne^{in\theta})(K_me^{im\theta})^{\dagger}v_m on=m=1nγnm(Qneinθ)(Kmeimθ)vm

公式中的 † \dagger 表示共轭转置操作。通过以上的递归、对角化、标量简化这几个步骤,即得到了 Retention 最基本的形式。之后作者又给出了三种表示形式:

① Retention的训练并行表示
并行的形式是最有利于模型训练的。并行表示的架构如下图所示:
在这里插入图片描述
Retention的训练并行表示公式如下:
Q = ( X W Q ) ⊙ Θ , K = ( X W K ) ⊙ Θ ‾ , V = X W V Θ n = e i n θ , D n m = { γ n − m , n ⩾ m 0 , n < m R e t e n t i o n ( X ) = ( Q K T ⊙ D ) V \begin{gather*} Q=(XW_Q)\odot \Theta,K=(XW_K)\odot \overline{\Theta},V=XW_V \\ \Theta_n=e^{in\theta},D_{nm}=\begin{cases}\gamma^{n-m},&n\geqslant m \\0,&n<m \end{cases}\\ Retention(X)=(QK^T\odot D)V \end{gather*} Q=(XWQ)Θ,K=(XWK)Θ,V=XWVΘn=einθ,Dnm={γnm,0,nmnmRetention(X)=(QKTD)V

架构图中的“GN”是GroupNorm的缩写。

② Retention的推理循环表示
Retention模块能实现像RNN一样的高效推理,是因为隐含状态 S n S_n Sn。循环表示的架构如下图所示:
在这里插入图片描述
Retention的推理循环表示公式如下:
S n = γ S n − 1 + K n T V n R e t e n t i o n ( X n ) = Q n S n , n = 1 , . . . , ∣ x ∣ \begin{gather*} S_n=\gamma S_{n-1}+K_n^TV_n\\ Retention(X_n)=Q_nS_n,n=1,...,|x| \end{gather*} Sn=γSn1+KnTVnRetention(Xn)=QnSn,n=1,...,x

③ Retention的分块循环表示
可以将并行和循环结构进行结合,以提高长序列的训练速度:将输入序列分成不同的块,在块内采用并行结构,而块间信息则采用循环结构进行传递。

2.2 Gated Multi-Scale Retention

RetNet每一层中的Retention子模块也是分了h个头,每个头用不同的 W Q , W K , W V W_Q,W_K,W_V WQ,WK,WV参数,同时每个头都采用不同的 γ \gamma γ常量。

针对输入X,MSR层的计算公式如下:

γ = 1 − 2 − 5 − a r a n g e ( 0 , h ) ∈ R h h e a d i = R e t e n t i o n ( X , γ i ) Y = G r o u p N o r m h ( C o n c a t ( h e a d 1 , . . . h e a d h ) ) M S R ( X ) = ( s w i s h ( X W G ) ⊙ Y ) W O \begin{gather*} \gamma=1-2^{-5-arange(0,h)}\in \mathbb{R}^h\\ head_i=Retention(X,\gamma_i)\\ Y=GroupNorm_h(Concat(head_1,...head_h))\\ MSR(X)=(swish(XW_G)\odot Y)W_O \end{gather*} γ=125arange(0,h)Rhheadi=Retention(X,γi)Y=GroupNormh(Concat(head1,...headh))MSR(X)=(swish(XWG)Y)WO

其中,GroupNorm对每个头的输出进行归一化,swish是激活函数用来引入非线性。

3.实验

3.1 与Transformer的比较

在这里插入图片描述
该图展示了基于 Transformer 和 RetNet 的语言模型的验证集的PPL(PPL越小,说明这句话契合的越好)。展示了三种模型大小情况下的曲线,当模型大小大于2B时,RetNet的表现开始优于Transformer。
在这里插入图片描述
文章在广泛的下游任务上比较了语言模型。使用6.7B模型对zero-shot和few-shot学习情况下,不同的数据集进行了实验,RetNet 实现了与 Transformer 相当的性能。

3.2 训练成本

在这里插入图片描述
文章比较了Transformer和RetNet的训练速度和内存消耗,其中训练序列长度为8192。RetNet在训练过程中比Transformer具有更高的内存效率和更高的吞吐量。

3.3 推理成本

在这里插入图片描述
图(a):内存。由于 KV 缓存,Transformer 的内存成本呈线性增加。相比之下,即使对于长序列,RetNet 的内存消耗也保持一致。

图(b):吞吐量。Transformer 的吞吐量随着解码长度的增加而下降。相比之下,RetNet利用Rentention的循环表示,在解码过程中具有更高的吞吐量。

图(c):延迟。增加batch size大小会导致 Transformer 的延迟变大。相比之下,RetNet 的解码延迟优于 Transformer,并且在不同batch大小和输入长度下几乎保持一致。

3.4 与Transformer变体的比较

在这里插入图片描述
文章与其它高效的Transformer变体进行比较,包括Linear Transformer,RWKV,H3和Hyena。评价指标是PPL。RetNet 在不同数据集上的性能优于之前的方法。

4.总结

本研究提出了一个新的网络RetNet,支持并行表示、循环表示和分块循环表示。与 Transformer 相比,RetNet 实现了更好的推理效率(在内存、速度和延迟方面)、良好的训练并行性和良好的性能。

RetNet是一个极具创新性和前瞻性的工作,给自然语言处理和大模型架构设计带来了新的思路和突破。

  • 0
    点赞
  • 16
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值