#Transformer#
近年来,Transformer已经成为了大型语言模型普遍采用的架构。虽然该架构在训练过程中能够实现出色的并行性和卓越性能,但在推理阶段却面临较高的成本。为了解决这一问题,文章引入了一种创新的网络结构,即RetNet。相对于传统的Transformer架构和多种变体,RetNet架构的优势是同时具备三个特点:训练可并行、推理成本低、具备良好的性能。
原文标题:Retentive Network: A Successor to Transformer for Large Language Models
作者信息:Yutao Sun,Li Dong,Shaohan Huang,Shuming Ma,Yuqing Xia,Jilong Xue,Jianyong Wang,Furu Wei(作者团队来自微软研究院和清华大学)
论文链接:https://arxiv.org/abs/2307.08621
1.介绍
在深度学习领域,循环神经网络(RNNs)以序列方式逐一处理输入数据,某个时间步骤上的输入处理取决于前一个时间步骤的隐藏状态,因此无法进行并行计算,从而降低了训练速度。而Transformer则采用了高度可并行化的自注意力机制,使得每个时间步的输出能够以Q、K、V矩阵的方式进行并行处理。不过,这种自注意力机制有助于Transformer在GPU上实现出色的并行性,但也导致了推理过程中的高成本。
研究者们一直在努力开发新一代架构,其目标是在保持训练并行性和Transformer性能的同时,实现高效的推理。要同时实现上述目标(即下图的“不可能三角“)是一项极具挑战性的任务。
该文提出了一个新的大语言模型自回归基础架构 Retentive Networks (RetNet),解决了“不可能三角”挑战。RetNet 在正中间,表示同时具备三个优点:推理成本低、训练可并行、良好的性能。而其他的架构 Linear Transformer、Recurrent Network 和 Transformer 都只能同时具备其中两个优点。
2.模型框架
RetNet 架构和 Transformer 类似,也是由L个相同的块堆叠而成。每一个 RetNet 块包含两个部分:一个multi-scale retention (MSR) 模块和一个feed-forward network (FFN) 模块。整体架构如下面的公式所示:
Y
l
=
M
S
R
(
L
N
(
X
l
)
)
+
X
l
X
l
+
1
=
F
F
N
(
L
N
(
Y
l
)
)
+
Y
l
\begin{gather*} Y^l=MSR(LN(X^l))+X^l \\ X^{l+1}=FFN(LN(Y^l))+Y^l \end{gather*}
Yl=MSR(LN(Xl))+XlXl+1=FFN(LN(Yl))+Yl
输入序列
{
x
i
}
i
=
1
∣
x
∣
\{x_i\}_{i=1}^{|x|}
{xi}i=1∣x∣通过一个词嵌入层转换为向量。然后使用打包好的嵌入
X
0
=
[
x
1
,
.
.
.
,
x
∣
x
∣
]
∈
R
∣
x
∣
×
d
m
o
d
e
l
X^0=[x_1,...,x_{|x|}]\in \mathbb{R}^{|x|\times d_{model}}
X0=[x1,...,x∣x∣]∈R∣x∣×dmodel作为输入,计算模型的输出
X
L
X_L
XL。
公式中LN(.)是LayerNorm。FFN部分,采用下面的公式计算:
F
F
N
(
X
)
=
g
e
l
u
(
X
W
1
)
W
2
FFN(X)=gelu(XW_1)W_2
FFN(X)=gelu(XW1)W2,其中
W
1
W_1
W1和
W
2
W_2
W2是参数矩阵。
在后面主要介绍MSR模块。
2.1 Retention
首先对词嵌入向量X序列的第n个时间步的向量乘以权重 ω \omega ω,得到投影 v n v_n vn: v ( n ) = X n ⋅ ω v v(n)=X_n \centerdot \omega_v v(n)=Xn⋅ωv。
然后类似Transformer架构,计算Q和K的投影: Q = X W Q , K = X W K Q=XW_Q,K=XW_K Q=XWQ,K=XWK
接着假设一个序列建模的问题,通过状态 s n s_n sn将 v n v_n vn映射为 o n o_n on向量,以递归的方式定义映射:
s n = A s n − 1 + K n T v n , A ∈ R d × d , K n ∈ R 1 × d o n = Q n s n = ∑ m = 1 n Q n A n − m K m T v m , Q n ∈ R 1 × d \begin{gather*} s_n=As_{n-1}+K_n^Tv_n,&A\in\mathbb{R}^{d\times d},K_n\in\mathbb{R}^{1\times d}\\ o_n=Q_ns_n=\sum_{m=1}^nQ_nA^{n-m}K_m^Tv_m,&Q_n\in\mathbb{R}^{1\times d} \end{gather*} sn=Asn−1+KnTvn,on=Qnsn=m=1∑nQnAn−mKmTvm,A∈Rd×d,Kn∈R1×dQn∈R1×d
其中,A是一个矩阵, K n K_n Kn表示时间步n对应的K投影,类似地, Q n Q_n Qn表示时间步n对应的Q投影。
接下来,利用对角化简化方程:
A
=
Λ
(
γ
e
i
θ
)
Λ
−
1
A=\Lambda (\gamma e^{i\theta})\Lambda^{-1}
A=Λ(γeiθ)Λ−1,得到新的
o
n
o_n
on表达式:
o
n
=
∑
m
=
1
n
Q
n
(
γ
e
i
θ
)
n
−
m
K
m
T
v
m
=
∑
m
=
1
n
(
Q
n
(
γ
e
i
θ
)
n
)
(
K
m
(
γ
e
i
θ
)
−
m
)
T
v
m
\begin{gather*} o_n=\sum_{m=1}^nQ_n(\gamma e^{i\theta})^{n-m}K_m^Tv_m\\ =\sum_{m=1}^n(Q_n(\gamma e^{i\theta})^n)(K_m(\gamma e^{i\theta})^{-m})^Tv_m \end{gather*}
on=m=1∑nQn(γeiθ)n−mKmTvm=m=1∑n(Qn(γeiθ)n)(Km(γeiθ)−m)Tvm
其中, Q n ( γ e i θ ) n Q_n(\gamma e^{i\theta})^n Qn(γeiθ)n, K m ( γ e i θ ) − m K_m(\gamma e^{i\theta})^{-m} Km(γeiθ)−m是xPOS,一种为transformer设计的位置编码。
再将
γ
\gamma
γ定义为一个标量,则可以将上述公式进一步简化为:
o
n
=
∑
m
=
1
n
γ
n
−
m
(
Q
n
e
i
n
θ
)
(
K
m
e
i
m
θ
)
†
v
m
o_n=\sum_{m=1}^n \gamma^{n-m}(Q_ne^{in\theta})(K_me^{im\theta})^{\dagger}v_m
on=m=1∑nγn−m(Qneinθ)(Kmeimθ)†vm
公式中的 † \dagger †表示共轭转置操作。通过以上的递归、对角化、标量简化这几个步骤,即得到了 Retention 最基本的形式。之后作者又给出了三种表示形式:
① Retention的训练并行表示
并行的形式是最有利于模型训练的。并行表示的架构如下图所示:
Retention的训练并行表示公式如下:
Q
=
(
X
W
Q
)
⊙
Θ
,
K
=
(
X
W
K
)
⊙
Θ
‾
,
V
=
X
W
V
Θ
n
=
e
i
n
θ
,
D
n
m
=
{
γ
n
−
m
,
n
⩾
m
0
,
n
<
m
R
e
t
e
n
t
i
o
n
(
X
)
=
(
Q
K
T
⊙
D
)
V
\begin{gather*} Q=(XW_Q)\odot \Theta,K=(XW_K)\odot \overline{\Theta},V=XW_V \\ \Theta_n=e^{in\theta},D_{nm}=\begin{cases}\gamma^{n-m},&n\geqslant m \\0,&n<m \end{cases}\\ Retention(X)=(QK^T\odot D)V \end{gather*}
Q=(XWQ)⊙Θ,K=(XWK)⊙Θ,V=XWVΘn=einθ,Dnm={γn−m,0,n⩾mn<mRetention(X)=(QKT⊙D)V
架构图中的“GN”是GroupNorm的缩写。
② Retention的推理循环表示
Retention模块能实现像RNN一样的高效推理,是因为隐含状态
S
n
S_n
Sn。循环表示的架构如下图所示:
Retention的推理循环表示公式如下:
S
n
=
γ
S
n
−
1
+
K
n
T
V
n
R
e
t
e
n
t
i
o
n
(
X
n
)
=
Q
n
S
n
,
n
=
1
,
.
.
.
,
∣
x
∣
\begin{gather*} S_n=\gamma S_{n-1}+K_n^TV_n\\ Retention(X_n)=Q_nS_n,n=1,...,|x| \end{gather*}
Sn=γSn−1+KnTVnRetention(Xn)=QnSn,n=1,...,∣x∣
③ Retention的分块循环表示
可以将并行和循环结构进行结合,以提高长序列的训练速度:将输入序列分成不同的块,在块内采用并行结构,而块间信息则采用循环结构进行传递。
2.2 Gated Multi-Scale Retention
RetNet每一层中的Retention子模块也是分了h个头,每个头用不同的 W Q , W K , W V W_Q,W_K,W_V WQ,WK,WV参数,同时每个头都采用不同的 γ \gamma γ常量。
针对输入X,MSR层的计算公式如下:
γ = 1 − 2 − 5 − a r a n g e ( 0 , h ) ∈ R h h e a d i = R e t e n t i o n ( X , γ i ) Y = G r o u p N o r m h ( C o n c a t ( h e a d 1 , . . . h e a d h ) ) M S R ( X ) = ( s w i s h ( X W G ) ⊙ Y ) W O \begin{gather*} \gamma=1-2^{-5-arange(0,h)}\in \mathbb{R}^h\\ head_i=Retention(X,\gamma_i)\\ Y=GroupNorm_h(Concat(head_1,...head_h))\\ MSR(X)=(swish(XW_G)\odot Y)W_O \end{gather*} γ=1−2−5−arange(0,h)∈Rhheadi=Retention(X,γi)Y=GroupNormh(Concat(head1,...headh))MSR(X)=(swish(XWG)⊙Y)WO
其中,GroupNorm对每个头的输出进行归一化,swish是激活函数用来引入非线性。
3.实验
3.1 与Transformer的比较
该图展示了基于 Transformer 和 RetNet 的语言模型的验证集的PPL(PPL越小,说明这句话契合的越好)。展示了三种模型大小情况下的曲线,当模型大小大于2B时,RetNet的表现开始优于Transformer。
文章在广泛的下游任务上比较了语言模型。使用6.7B模型对zero-shot和few-shot学习情况下,不同的数据集进行了实验,RetNet 实现了与 Transformer 相当的性能。
3.2 训练成本
文章比较了Transformer和RetNet的训练速度和内存消耗,其中训练序列长度为8192。RetNet在训练过程中比Transformer具有更高的内存效率和更高的吞吐量。
3.3 推理成本
图(a):内存。由于 KV 缓存,Transformer 的内存成本呈线性增加。相比之下,即使对于长序列,RetNet 的内存消耗也保持一致。
图(b):吞吐量。Transformer 的吞吐量随着解码长度的增加而下降。相比之下,RetNet利用Rentention的循环表示,在解码过程中具有更高的吞吐量。
图(c):延迟。增加batch size大小会导致 Transformer 的延迟变大。相比之下,RetNet 的解码延迟优于 Transformer,并且在不同batch大小和输入长度下几乎保持一致。
3.4 与Transformer变体的比较
文章与其它高效的Transformer变体进行比较,包括Linear Transformer,RWKV,H3和Hyena。评价指标是PPL。RetNet 在不同数据集上的性能优于之前的方法。
4.总结
本研究提出了一个新的网络RetNet,支持并行表示、循环表示和分块循环表示。与 Transformer 相比,RetNet 实现了更好的推理效率(在内存、速度和延迟方面)、良好的训练并行性和良好的性能。
RetNet是一个极具创新性和前瞻性的工作,给自然语言处理和大模型架构设计带来了新的思路和突破。