Raki的读paper小记:Retentive Network: A Successor to Transformer for Large Language Models

82 篇文章 10 订阅

Abstract&Introduction&Related Work

  • 研究任务
    语言模型的基础架构
  • 已有方法和相关工作
    • S4,H3,Hyena,Linear Transformer
    • 用核函数近似注意力,以便将自回归推理重写为循环形式
    • 回归到使用循环模型进行高效推理,但牺牲了训练并行性。为了弥补这一点,使用元素级操作[PAA+23]进行加速,但同时损害了表示能力和性能
    • 尝试用其他机制取代注意力,例如S4[GGR21]及其变体[DFS+22,PMN+23]
  • 面临挑战
  • 创新思路
    • RetNet = linear attention + rope + 显式衰减(即 γ \gamma γ
  • 实验结论
    实现了不可能三角,实现了O(1)推理
    在这里插入图片描述

在这里插入图片描述

Retentive Networks

先来看一下Retention跟Attention的区别,首先第一眼感觉retention有点像RNN和LSTM

attention的计算方式是QK做矩阵乘法,使用query和key计算权重分布,对value加权

retention使用了一个线性衰减参数 γ \gamma γ,使用了一个状态向量S

给定输入 X ∈ R ∣ x ∣ × d m o d e l X\in\mathbb{R}^{|x|\times d_{\mathrm{model}}} XRx×dmodel,我们将其投影到一维函数 v ( n ) = X n ⋅ w V v(n) = X_n · w_V v(n)=XnwV。考虑一个序列建模问题,通过状态 s n s_n sn v ( n ) v(n) v(n)映射为 o ( n ) o(n) o(n), 为简单起见,用 v n 、 o n v_n、o_n vnon 表示 v ( n ) v(n) v(n) o ( n ) o(n) o(n), 以递归方式形式化映射过程:

s n = A s n − 1 + K n T v n , A ∈ R d × d , K n ∈ R 1 × d o n = Q n s n = ∑ m = 1 n Q n A n − m K m T v m , Q n ∈ R 1 × d \begin{aligned}s_n&=As_{n-1}+K_n^\mathsf{T}v_n,&A\in\mathbb{R}^{d\times d},K_n\in\mathbb{R}^{1\times d}\\o_n&=Q_ns_n=\sum_{m=1}^nQ_nA^{n-m}K_m^\mathsf{T}v_m,&Q_n\in\mathbb{R}^{1\times d}\end{aligned} snon=Asn1+KnTvn,=Qnsn=m=1nQnAnmKmTvm,ARd×d,KnR1×dQnR1×d

将vn映射到状态向量sn,并通过线性变换来递归地编码序列信息。接下来使投影 Q n Q_n Qn K n K_n Kn变得与内容相关: Q = X W Q , K = X W K Q=XW_{Q},\quad K=XW_{K} Q=XWQ,K=XWK

将矩阵A对角化 A = Λ ( γ e i θ ) Λ − 1 A=\Lambda(\gamma e^{i\theta})\Lambda^{-1} A=Λ(γeiθ)Λ1,由   A n − m = Λ ( γ e i θ ) n − m Λ − 1 \text{ }A^{n-m}= \Lambda(\gamma e^{i\theta})^{n-m}\Lambda^{-1}  Anm=Λ(γeiθ)nmΛ1得到 o ( n ) o(n) o(n)改进后的输出表达式:
o n = ∑ m = 1 n Q n ( γ e i θ ) n − m K m ⊺ v m = ∑ m = 1 n ( Q n ( γ e i θ ) n ) ( K m ( γ e i θ ) − m ) ⊺ v m \begin{aligned} o_{n}& =\sum_{m=1}^nQ_n(\gamma e^{i\theta})^{n-m}K_m^\intercal v_m \\ &=\sum_{m=1}^n(Q_n(\gamma e^{i\theta})^n)(K_m(\gamma e^{i\theta})^{-m})^\intercal v_m \end{aligned} on=m=1nQn(γeiθ)nmKmvm=m=1n(Qn(γeiθ)n)(Km(γeiθ)m)vm

Q n ( γ e i θ ) n , K m ( γ e i θ ) − m Q_{n}(\gamma e^{i\theta})^{n},K_{m}(\gamma e^{i\theta})^{-m} Qn(γeiθ)n,Km(γeiθ)m 在xPos中被大家周知,这里使用了类似Transformer中提出的相对位置嵌入。我们进一步将γ简化为一个标量: o n = ∑ m = 1 n γ n − m ( Q n e i n θ ) ( K m e i m θ ) † v m o_n=\sum_{m=1}^n\gamma^{n-m}(Q_ne^{in\theta})(K_me^{im\theta})^\dagger v_m on=m=1nγnm(Qneinθ)(Kmeimθ)vm

The Parallel Representation of Retention

最后retention的公式如下:
Q = ( X W Q ) ⊙ Θ , K = ( X W K ) ⊙ Θ ‾ , V = X W V Θ n = e i n θ , D n m = { γ n − m , n ≥ m 0 , n < m Retention ( X ) = ( Q K ⊺ ⊙ D ) V Q=(XW_Q)\odot\Theta,\quad K=(XW_K)\odot\overline{\Theta},\quad V=XW_V\\\Theta_n=e^{in\theta},\quad D_{nm}=\left\{\begin{matrix}\gamma^{n-m},&n\geq m\\0,&n<m\end{matrix}\right.\\\text{Retention}(X)=(QK^\intercal\odot D)V Q=(XWQ)Θ,K=(XWK)Θ,V=XWVΘn=einθ,Dnm={γnm,0,nmn<mRetention(X)=(QKD)V
在这里插入图片描述
也可以写成循环的方式:
S n = γ S n − 1 + K n ⊺ V n Retention ( X n ) = Q n S n , n = 1 , ⋯   , ∣ x ∣ \begin{array}{l}S_n=\gamma S_{n-1}+K_n^\intercal V_n\\\text{Retention}(X_n)=Q_nS_n,\quad n=1,\cdots,|x|\end{array} Sn=γSn1+KnVnRetention(Xn)=QnSn,n=1,,x

保持机制的分块循环表示

为了加速训练,特别是对于较长的序列,这里采用了一种并行表示和循环表示的混合形式。将输入序列分成块。在每个块内,采用并行表示进行计算。跨块信息通过循环表示传递。具体来说,设B表示块长度。我们通过以下方式计算第i个块的保持输出:
Q [ i ] = Q B i : B ( i + 1 ) , K [ i ] = K B i : B ( i + 1 ) , V [ i ] = V B i : B ( i + 1 ) R i = K [ i ] ⊺ V [ i ] + γ B R i − 1 Retention ⁡ ( X [ i ] ) = ( Q [ i ] K [ i ] T ⊙ D ) V [ i ] ⏟ Inner-Chunk + ( Q [ i ] R i ) ⊙ ξ ⏟ Cross-Chunk , ξ i j = γ i + 1 \begin{aligned} &Q_{[i]}=Q_{Bi:B(i+1)},\quad K_{[i]}=K_{Bi:B(i+1)},\quad V_{[i]}=V_{Bi:B(i+1)} \\ R_{i}& =K_{[i]}^\intercal V_{[i]}+\gamma^BR_{i-1} \\ \operatorname{Retention}(X_{[i]})& =\underbrace{(Q_{[i]}K_{[i]}^{\mathsf{T}}\odot D)V_{[i]}}_\text{Inner-Chunk}+\underbrace{(Q_{[i]}R_i)\odot\xi}_\text{Cross-Chunk},\quad\xi_{ij}=\gamma^{i+1} \end{aligned} RiRetention(X[i])Q[i]=QBi:B(i+1),K[i]=KBi:B(i+1),V[i]=VBi:B(i+1)=K[i]V[i]+γBRi1=Inner-Chunk (Q[i]K[i]TD)V[i]+Cross-Chunk (Q[i]Ri)ξ,ξij=γi+1

跟attention的示意图对比:
在这里插入图片描述

Gated Multi-Scale Retention

在每一层中,我们使用 h = d m o d e l / d h = d_{model}/d h=dmodel/d个retention heads,其中d是头的维度。这些头使用不同的参数矩阵 W Q 、 W K 、 W V ∈ R d m o d e l × d m o d e l W_Q、W_K、W_V\in\mathbb{R}^{d_{\mathrm{model}}\times d_{\mathrm{model}}} WQWKWVRdmodel×dmodel 。多尺度保持(MSR)为每个头分配不同的γ。为了简单起见,我们在不同层之间设置相同的γ,并将它们固定。此外,我们添加了一个Swish门[HG16, RZL17]来增加保持层的非线性。具体而言,给定输入X,我们定义该层为:
γ = 1 − 2 − 5 − a r a n g e ( 0 , h ) ∈ R h h e a d i = Retention ⁡ ( X , γ i ) Y = G r o u p Norm h ( C o n c a t ( h e a d 1 , ⋯   , h e a d h ) ) M S R ( X ) = ( s w i s h ( X W G ) ⊙ Y ) W O \begin{aligned} \gamma& =1-2^{-5-\mathrm{arange}(0,h)}\in\mathbb{R}^h \\ \mathrm{head}_{i}& =\operatorname{Retention}(X,\gamma_i) \\ \boldsymbol{Y}& =\mathrm{Group}\text{Norm}_h(\mathrm{Concat}(\mathrm{head}_1,\cdots,\mathrm{head}_h)) \\ \mathop{\mathrm{MSR}}(X)& =(\mathrm{swish}(XW_{G})\odot Y)W_{O} \end{aligned} γheadiYMSR(X)=125arange(0,h)Rh=Retention(X,γi)=GroupNormh(Concat(head1,,headh))=(swish(XWG)Y)WO

组归一化(GroupNorm)[WH18]用于对每个头的输出进行归一化,遵循在[SPP+19]中提出的SubLN。请注意,不同的γ尺度会导致不同的方差统计数据。因此,我们分别对头的输出进行归一化。保持机制的伪代码总结在图4中:
在这里插入图片描述
归一化:
在这里插入图片描述
packed embeddings X 0 = [ x 1 , ⋯   , x ∣ x ∣ ] ∈ R ∣ x ∣ × d m o d e l X^0=[\boldsymbol{x}_1,\cdots,\boldsymbol{x}_{|x|}]\in\mathbb{R}^{|x|\times d_{\mathrm{model}}} X0=[x1,,xx]Rx×dmodel 作为模型输入,计算模型输出:

Y l = M S R ( L N ( X l ) ) + X l X l + 1 = F F N ( L N ( Y l ) ) + Y l \begin{aligned}Y^l&=\mathrm{MSR}(\mathrm{LN}(X^l))+X^l\\X^{l+1}&=\mathrm{FFN}(\mathrm{LN}(Y^l))+Y^l\end{aligned} YlXl+1=MSR(LN(Xl))+Xl=FFN(LN(Yl))+Yl

LN是layernorm
F F N ( X ) = g e l u ( X W 1 ) W 2 FFN(X) = gelu(XW_1)W_2 FFN(X)=gelu(XW1)W2

在训练过程中,我们使用并行和分块循环表示。序列内或块内的并行化有效地利用GPU加速计算。更重要的是,分块循环特别适用于长序列训练,无论是在FLOPs还是内存消耗方面都非常高效

在推理过程中,我们采用循环表示,这非常适合自回归解码。O(1)复杂度降低了内存和推理延迟,同时实现了相同的结果

与其他方法的差异:
在这里插入图片描述

Experiments

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

Conclusions

在本研究中,我们提出了保持网络(RetNet)用于序列建模,实现了多种表示方式,即并行、循环和分块循环。相比Transformer,RetNet在推理效率(内存、速度和延迟方面)、训练并行化和竞争性能方面都取得了显著优势。上述优势使得RetNet成为大型语言模型中理想的Transformer继任者,特别是考虑到O(1)推理复杂度带来的部署优势。未来,我们希望在模型规模[CDH+22]和训练步骤方面扩展RetNet。此外,通过压缩长期记忆,保持机制能够高效地与结构化提示[HSB+22b]结合使用。我们还将使用RetNet作为骨干架构来训练多模式大型语言模型[HSB+22a,HDW+23,PWD+23]。此外,我们对在各种边缘设备上部署RetNet模型,如移动手机等,也充满兴趣。

Remark

用线性注意力+rope位置编码+权重衰减,得到了在小任务上的效率和效果全方面吊打transformer的结果,但是长距离建模能力暂时未知,毕竟transformer的建模是能看到前面所有模块,而retnet是由前一个状态转移得来,有cherry picking可能,不过不管如何,都是一个了不起的工作,如tianxiang哥所说,文本的diffusion时期可能很快就要到来,一举成为取代attention方式decoder的超强方法,甚至于不用regressive方式的解码(but我自己觉得自回归就是我心目中的终极奥义了,至少是终极奥义的组成部分,不可能被完全不使用),期待后续工作推进在长序列上的建模能力,不管如何,在low resource的情况下,retnet这种架构必然会大展身手,毕竟transformer的推理代价过高导致很多情况都是不可能使用(以目前的硬件水平 不过我很相信老黄),嗯 希望成为或者引出下一个划时代的工作

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值