简介
微软亚洲研究院上线了大模型新架构的论文“Retentive Network: A Successor to Transformer for Large Language Models”,该基础架构采用了新的 Retention 机制来代替 Attention,向 Transformer 发起挑战!相对于 Transformer 架构的优势是同时具备:训练可并行、推理成本低和良好的性能,不可能三角。
论文中给出一个很形象的示意图,RetNet 在正中间表示同时具备三个优点,而其他的架构 Linear Transformer、Recurrent Network 和 Transformer 都只能同时具备其中两个优点。
实验结果
论文给出的 RetNet 和 Transformer 的对比实验结果:
当输入序列长度增加的时候,RetNet 的 GPU 显存占用一直是稳定的和权值差不多,而 Transformer 则是和输入长度成正比。
首先看红色线和紫色线,都是输入长度在 8192 下,RetNet 和 Transformer 推理延时的对比。
可以看到当 batch size 增加的时候, RetNet 的推理延时也还是很稳定,而 Transformer 的推理延时则是和 batch size 成正比。
而 Transformer 即使是输入长度缩小到 1024 ,推理延时也还是比 RetNet 要高。
LLM困惑度
p
p
l
∈
[
1
,
∞
]
ppl \in [1,\infty]
ppl∈[1,∞]上的实验也比之前的一些工作要好。
在67亿参数规模上的LLM实验效果也比Transformer好
训练效率的对比,看得出来去掉softmax之后有一定的提升。
RetNet 架构解读
与标准自注意力机制相比,保持机制有几大特点:
引入位置相关的指数衰减项取代 softmax,简化了计算,同时使前步的信息以衰减的形式保留下来。
引入复数空间表达位置信息,取代绝对或相对位置编码,容易转换为递归形式。
另外,保持机制使用多尺度的衰减率,增加了模型的表达能力,并利用 GroupNorm 的缩放不变性来提高 Retention 层的数值精度。
RetNet 架构和 Transformer 类似,也是堆叠
L
L
L层同样的模块,每个模块内部包含两个子模块:一个 multi-scale retention(MSR)和一个 feed-forward network (FFN)。
下面详细解读一下这个 retention 子模块。
首先给定一个输入序列
{
x
i
}
i
=
1
∣
x
∣
\left\{ x_i \right\}_{i=1}^{|x|}
{xi}i=1∣x∣:
x
=
x
1
⋅
⋅
⋅
x
∣
x
∣
x = x_{1} \cdot \cdot \cdot x_{|x|}
x=x1⋅⋅⋅x∣x∣
其中
∣
x
∣
|x|
∣x∣表示序列的长度。然后输入序列首先经过 embedding 层得到词嵌入向量:
X
0
=
[
x
1
,
…
,
x
∣
x
∣
]
∈
R
∣
x
∣
×
d
X^{0} = [x_{1}, \dots, x_{|x|}] \in \mathbb{R}^{|x| \times d}
X0=[x1,…,x∣x∣]∈R∣x∣×d
其中
d
d
d 表示隐含层的维度。
Retention 机制
首先对给定输入词嵌入向量序列
X
∈
R
∣
x
∣
×
d
X \in \mathbb{R}^{|x| \times d}
X∈R∣x∣×d中的每个时间步
n
n
n 的向量
X
n
∈
R
1
×
d
X_{n} \in \mathbb{R}^{1 \times d}
Xn∈R1×d都乘以权值
w
v
∈
R
d
×
d
w_{v} \in \mathbb{R}^{d \times d}
wv∈Rd×d得到
v
n
∈
R
1
×
d
v_{n} \in \mathbb{R}^{1 \times d}
vn∈R1×d:
v
n
=
X
n
⋅
w
V
v_{n} = X_{n} \cdot w_{V}
vn=Xn⋅wV
然后同样有类似 Transformer 架构的
Q
Q
Q 和
K
K
K 的投影:
Q
=
X
W
Q
,
K
=
X
W
K
Q = XW_{Q}, K = XW_{K}
Q=XWQ,K=XWK
其中
W
Q
∈
R
d
×
d
W_{Q} \in \mathbb{R}^{d \times d}
WQ∈Rd×d,
W
K
∈
R
d
×
d
W_{K} \in \mathbb{R}^{d \times d}
WK∈Rd×d是需要学习的权值。
接着假设现在有一个序列建模的问题,通过状态
s
n
∈
R
d
×
d
s_{n} \in \mathbb{R}^{d \times d}
sn∈Rd×d将
v
n
v_n
vn映射为
o
n
o_n
on向量。首先来看论文中给出的映射方式定义:
s
n
=
A
s
n
−
1
+
K
n
T
⋅
v
n
s_{n} = As_{n-1}+ K_n^{T} \cdot v_{n}
sn=Asn−1+KnT⋅vn
o
n
=
Q
n
⋅
s
n
=
∑
m
=
1
n
Q
n
A
n
−
m
K
m
T
v
m
o_{n} = Q_{n} \cdot s_{n} = \sum_{m=1}^{n} Q_{n} A^{n-m} {K_{m}}^{T} v_{m}
on=Qn⋅sn=m=1∑nQnAn−mKmTvm
其中
A
∈
R
d
×
d
A \in \mathbb{R}^{d \times d}
A∈Rd×d是一个矩阵,
K
n
∈
R
1
×
d
K_{n} \in \mathbb{R}^{1 \times d}
Kn∈R1×d表示时间步
n
n
n 对应的
K
K
K 投影,则
K
T
v
n
∈
R
d
×
d
K^{T}v_{n} \in \mathbb{R}^{d \times d}
KTvn∈Rd×d。同样
Q
n
∈
R
1
×
d
Q_{n} \in \mathbb{R}^{1 \times d}
Qn∈R1×d表示时间步
n
n
n 对应的
Q
Q
Q 投影。
那么上面公式中的
o
n
o_n
on计算公式是怎么得出来呢,下面详细解释一下,首先将
Q
n
s
n
Q_{n}s_{n}
Qnsn展开:
Q
n
s
n
=
Q
n
(
A
s
n
−
1
+
K
n
T
v
n
)
Q_{n}s_{n} = Q_{n}(A s_{n-1} + {K_{n}}^{T} v_{n})
Qnsn=Qn(Asn−1+KnTvn)
=
Q
n
(
A
(
A
s
n
−
2
+
K
n
−
1
T
v
n
−
1
)
+
K
n
T
v
n
)
= Q_{n}(A (A s_{n-2} + {K_{n-1}}^{T} v_{n-1}) + {K_{n}}^{T} v_{n})
=Qn(A(Asn−2+Kn−1Tvn−1)+KnTvn)
=
Q
n
(
A
2
s
n
−
2
+
A
1
K
n
−
1
T
v
n
−
1
+
A
0
K
n
T
v
n
)
= Q_{n}(A^2 s_{n-2} + A^1{K_{n-1}}^{T} v_{n-1} + A^0{K_{n}}^{T} v_{n})
=Qn(A2sn−2+A1Kn−1Tvn−1+A0KnTvn)
=
Q
n
(
A
2
(
s
n
−
3
+
A
1
K
n
−
2
T
v
n
−
2
)
+
A
1
K
n
−
1
T
v
n
−
1
+
A
0
K
n
T
v
n
)
= Q_{n}(A^2( s_{n-3} + A^1{K_{n-2}}^{T} v_{n-2}) + A^1{K_{n-1}}^{T} v_{n-1} + A^0{K_{n}}^{T} v_{n})
=Qn(A2(sn−3+A1Kn−2Tvn−2)+A1Kn−1Tvn−1+A0KnTvn)
=
Q
n
(
A
3
s
n
−
3
+
A
2
K
n
−
2
T
v
n
−
2
+
A
1
K
n
−
1
T
v
n
−
1
+
A
0
K
n
T
v
n
)
= Q_{n}(A^3 s_{n-3} + A^2{K_{n-2}}^{T} v_{n-2} + A^1{K_{n-1}}^{T} v_{n-1} + A^0{K_{n}}^{T} v_{n})
=Qn(A3sn−3+A2Kn−2Tvn−2+A1Kn−1Tvn−1+A0KnTvn)
其中
A
0
A^0
A0表示单位矩阵(主对角线元素为1,其余元素为0的方阵)。然后我们假定
s
0
s_0
s0 为初始状态元素为全0的矩阵,则有:
s
1
=
A
s
0
+
K
1
T
v
1
=
K
1
T
v
1
s_{1} = A s_{0} + {K_{1}}^{T} v_{1}= {K_{1}}^{T} v_{1}
s1=As0+K1Tv1=K1Tv1
再继续上述推导过程:
Q
n
s
n
=
Q
n
(
A
3
s
n
−
3
+
A
2
K
n
−
2
T
v
n
−
2
+
A
1
K
n
−
1
T
v
n
−
1
+
A
0
K
n
T
v
n
)
Q_{n}s_{n} = Q_{n}(A^3 s_{n-3} + A^2{K_{n-2}}^{T} v_{n-2} + A^1{K_{n-1}}^{T} v_{n-1} + A^0{K_{n}}^{T} v_{n})
Qnsn=Qn(A3sn−3+A2Kn−2Tvn−2+A1Kn−1Tvn−1+A0KnTvn)
=
Q
n
(
A
n
−
(
n
−
3
)
s
n
−
3
+
A
n
−
(
n
−
2
)
K
n
−
2
T
v
n
−
2
+
A
n
−
(
n
−
1
)
K
n
−
1
T
v
n
−
1
+
A
n
−
(
n
−
0
)
K
n
T
v
n
)
= Q_{n}(A^{n-(n-3)} s_{n-3} + A^{n-(n-2)}{K_{n-2}}^{T} v_{n-2} + A^{n-(n-1)}{K_{n-1}}^{T} v_{n-1} + A^{n-(n-0)}{K_{n}}^{T} v_{n})
=Qn(An−(n−3)sn−3+An−(n−2)Kn−2Tvn−2+An−(n−1)Kn−1Tvn−1+An−(n−0)KnTvn)
所以根据上述推导过程和条件归纳可得:
Q
n
s
n
=
Q
n
(
A
n
−
1
s
1
+
A
n
−
2
K
2
T
v
2
+
.
.
.
+
A
n
−
(
n
−
2
)
K
n
−
2
T
v
n
−
2
+
A
n
−
(
n
−
1
)
K
n
−
1
T
v
n
−
1
+
A
n
−
(
n
−
0
)
K
n
T
v
n
)
Q_{n}s_{n}= Q_{n}(A^{n-1}s_1+A^{n-2}{K_{2}}^{T} v_{2} +...+ A^{n-(n-2)}{K_{n-2}}^{T} v_{n-2} + A^{n-(n-1)}{K_{n-1}}^{T} v_{n-1} + A^{n-(n-0)}{K_{n}}^{T} v_{n})
Qnsn=Qn(An−1s1+An−2K2Tv2+...+An−(n−2)Kn−2Tvn−2+An−(n−1)Kn−1Tvn−1+An−(n−0)KnTvn)
=
Q
n
(
A
n
−
1
K
1
T
v
1
+
A
n
−
2
K
2
T
v
2
+
.
.
.
+
A
n
−
(
n
−
2
)
K
n
−
2
T
v
n
−
2
+
A
n
−
(
n
−
1
)
K
n
−
1
T
v
n
−
1
+
A
n
−
(
n
−
0
)
K
n
T
v
n
)
= Q_{n}(A^{n-1}{K_{1}}^{T} v_{1}+A^{n-2}{K_{2}}^{T} v_{2} +...+ A^{n-(n-2)}{K_{n-2}}^{T} v_{n-2} + A^{n-(n-1)}{K_{n-1}}^{T} v_{n-1} + A^{n-(n-0)}{K_{n}}^{T} v_{n})
=Qn(An−1K1Tv1+An−2K2Tv2+...+An−(n−2)Kn−2Tvn−2+An−(n−1)Kn−1Tvn−1+An−(n−0)KnTvn)
=
∑
m
=
1
n
Q
n
A
n
−
m
K
m
T
v
m
= \sum_{m=1}^{n} Q_{n} A^{n-m} {K_{m}}^{T} v_{m}
=m=1∑nQnAn−mKmTvm
然后我们来看一下
A
A
A 矩阵是什么,论文中定义了
A
A
A 是一个可对角化的矩阵,具体定义为:
A
=
Λ
(
γ
e
i
θ
)
Λ
−
1
A= \Lambda(\gamma e^{i\theta}) \Lambda^{-1}
A=Λ(γeiθ)Λ−1
其中
γ
,
θ
∈
R
d
\gamma,\theta \in \mathbb{R}^{d}
γ,θ∈Rd都是
d
d
d维向量,
Λ
\Lambda
Λ是可逆矩阵。
欧拉公式:
e
i
x
=
c
o
s
x
+
i
⋅
s
i
n
x
e^{ix} = cos x + i·sin x
eix=cosx+i⋅sinx
其中
x
x
x 表示任意实数,
e
e
e 是自然对数的底数,
i
i
i 是复数中的虚数单位,也可以表示为实部
c
o
s
x
cos \ x
cos x ,虚部
s
i
n
x
sin \ x
sin x 的一个复数,欧拉公式建立了指数函数、三角函数和复数之间的桥梁。
而这里
θ
\theta
θ 是一个
d
d
d 维向量:
θ
=
[
θ
1
,
θ
2
,
.
.
.
,
θ
d
−
1
,
θ
d
]
\theta=[\theta_1,\theta_2,...,\theta_{d-1},\theta_d]
θ=[θ1,θ2,...,θd−1,θd]
则
e
i
θ
e^{i\theta}
eiθ 也就是将向量元素两两一组表示分别表示为复数的实部和虚部:
e
i
θ
=
[
c
o
s
θ
1
,
s
i
n
θ
2
,
.
.
.
,
c
o
s
θ
d
−
1
,
s
i
n
θ
d
]
e^{i\theta}=[cos\theta_1,sin\theta_2,...,cos\theta_{d-1},sin\theta_d]
eiθ=[cosθ1,sinθ2,...,cosθd−1,sinθd]
然后
γ
e
i
θ
\gamma e^{i\theta}
γeiθ 就是一个对角矩阵,对角元素的值就对应将
γ
\gamma
γ 和
e
i
θ
e^{i\theta}
eiθ 转成复数向量相乘再将结果转回实数向量的结果。
现在我们知道了矩阵
A
A
A 的构成就能得到:
A
n
−
m
=
(
Λ
(
γ
e
i
θ
)
Λ
−
1
)
n
−
m
A^{n-m}= (\Lambda(\gamma e^{i\theta}) \Lambda^{-1})^{n-m}
An−m=(Λ(γeiθ)Λ−1)n−m
这里因为
Λ
\Lambda
Λ 是可逆矩阵则有性质
Λ
Λ
−
1
=
Λ
−
1
Λ
=
I
\Lambda \Lambda^{-1}=\Lambda^{-1}\Lambda =I
ΛΛ−1=Λ−1Λ=I
其中
I
I
I为单位矩阵,则将
n
−
m
n-m
n−m 次方展开:
A
n
−
m
=
Λ
(
γ
e
i
θ
)
Λ
−
1
Λ
(
γ
e
i
θ
)
Λ
−
1
.
.
.
.
Λ
(
γ
e
i
θ
)
Λ
−
1
A^{n-m}= \Lambda(\gamma e^{i\theta}) \Lambda^{-1} \Lambda(\gamma e^{i\theta}) \Lambda^{-1}.... \Lambda(\gamma e^{i\theta}) \Lambda^{-1}
An−m=Λ(γeiθ)Λ−1Λ(γeiθ)Λ−1....Λ(γeiθ)Λ−1
就是
n
−
m
n-m
n−m 个
Λ
(
γ
e
i
θ
)
Λ
−
1
\Lambda(\gamma e^{i\theta}) \Lambda^{-1}
Λ(γeiθ)Λ−1 矩阵相乘,中间相邻的
Λ
−
1
Λ
\Lambda^{-1}\Lambda
Λ−1Λ 都消掉了,所以可得:
A
n
−
m
=
Λ
(
γ
e
i
θ
)
n
−
m
Λ
−
1
A^{n-m}= \Lambda(\gamma e^{i\theta}) ^{n-m}\Lambda^{-1}
An−m=Λ(γeiθ)n−mΛ−1
然后我们回到计算
o
n
o_n
on 的公式:
o
n
=
∑
m
=
1
n
Q
n
A
n
−
m
K
m
T
v
m
o_{n} = \sum_{m=1}^{n} Q_{n} A^{n-m} {K_{m}}^{T} v_{m}
on=m=1∑nQnAn−mKmTvm
=
∑
m
=
1
n
Q
n
(
Λ
(
γ
e
i
θ
)
n
−
m
Λ
−
1
)
K
m
T
v
m
= \sum_{m=1}^{n} Q_{n}( \Lambda(\gamma e^{i\theta}) ^{n-m}\Lambda^{-1} ){K_{m}}^{T} v_{m}
=m=1∑nQn(Λ(γeiθ)n−mΛ−1)KmTvm
=
∑
m
=
1
n
X
n
W
Q
(
Λ
(
γ
e
i
θ
)
n
−
m
Λ
−
1
)
(
X
m
W
K
)
T
v
m
= \sum_{m=1}^{n} X_{n}W_{Q}( \Lambda(\gamma e^{i\theta}) ^{n-m}\Lambda^{-1} )(X_{m}W_{K})^{T} v_{m}
=m=1∑nXnWQ(Λ(γeiθ)n−mΛ−1)(XmWK)Tvm
=
∑
m
=
1
n
X
n
W
Q
(
Λ
(
γ
e
i
θ
)
n
−
m
Λ
−
1
)
X
m
T
W
K
T
v
m
= \sum_{m=1}^{n} X_{n}W_{Q}( \Lambda(\gamma e^{i\theta}) ^{n-m}\Lambda^{-1} )X_{m}^{T}W_{K}^{T} v_{m}
=m=1∑nXnWQ(Λ(γeiθ)n−mΛ−1)XmTWKTvm
接着论文中提出把
Λ
\Lambda
Λ 吸收进
W
Q
W_Q
WQ 和
W
K
W_K
WK 也就是
W
Q
Λ
W_{Q} \Lambda
WQΛ 和
Λ
−
1
W
K
T
\Lambda^{-1}W_{K}^{T}
Λ−1WKT 分别用
W
Q
W_Q
WQ 和
W
K
T
W_{K}^{T}
WKT 替代当作学习的权值,那么可得:
o
n
=
∑
m
=
1
n
Q
n
(
γ
e
i
θ
)
n
−
m
K
m
T
v
m
o_{n} = \sum_{m=1}^{n} Q_{n}(\gamma e^{i\theta}) ^{n-m} K_{m}^{T} v_{m}
on=m=1∑nQn(γeiθ)n−mKmTvm
=
∑
m
=
1
n
Q
n
(
γ
e
i
θ
)
n
(
γ
e
i
θ
)
−
m
K
m
T
v
m
= \sum_{m=1}^{n} Q_{n}(\gamma e^{i\theta}) ^{n}(\gamma e^{i\theta}) ^{-m} K_{m}^{T} v_{m}
=m=1∑nQn(γeiθ)n(γeiθ)−mKmTvm
=
∑
m
=
1
n
Q
n
(
γ
e
i
θ
)
n
(
K
m
γ
e
i
θ
)
−
m
)
T
v
m
= \sum_{m=1}^{n} Q_{n}(\gamma e^{i\theta}) ^{n}(K_{m}\gamma e^{i\theta}) ^{-m} )^{T} v_{m}
=m=1∑nQn(γeiθ)n(Kmγeiθ)−m)Tvm
=
∑
m
=
1
n
Q
n
(
γ
n
e
i
θ
n
)
(
K
m
γ
−
m
e
i
θ
(
−
m
)
)
T
v
m
= \sum_{m=1}^{n} Q_{n}(\gamma ^{n} e^{i\theta n})(K_{m}\gamma ^{-m}e^{i\theta(-m)})^{T} v_{m}
=m=1∑nQn(γneiθn)(Kmγ−meiθ(−m))Tvm
接着将公式简化,将
γ
\gamma
γ 改为一个实数常量,那么可得:
o
n
=
∑
m
=
1
n
Q
n
(
γ
n
e
i
θ
n
)
(
K
m
γ
−
m
e
i
θ
(
−
m
)
)
T
v
m
o_{n} = \sum_{m=1}^{n} Q_{n}(\gamma ^{n} e^{i\theta n})(K_{m}\gamma ^{-m}e^{i\theta(-m)})^{T} v_{m}
on=m=1∑nQn(γneiθn)(Kmγ−meiθ(−m))Tvm
=
∑
m
=
1
n
γ
n
−
m
(
Q
n
e
i
θ
n
)
(
K
m
e
i
θ
(
−
m
)
)
T
v
m
= \sum_{m=1}^{n} \gamma ^{n-m}(Q_{n}e^{i\theta n})(K_{m}e^{i\theta(-m)})^{T} v_{m}
=m=1∑nγn−m(Qneiθn)(Kmeiθ(−m))Tvm
在继续推导前,先来仔细看一下
e
i
θ
(
−
m
)
e^{i\theta(-m)}
eiθ(−m),借助欧拉公式展开:
e
i
θ
(
−
m
)
=
[
c
o
s
−
m
θ
1
,
s
i
n
−
m
θ
2
,
.
.
.
,
c
o
s
−
m
θ
d
−
1
,
s
i
n
−
m
θ
d
]
e^{i\theta(-m)}=[cos-m\theta_1,sin-m\theta_2,...,cos-m\theta_{d-1},sin-m\theta_d]
eiθ(−m)=[cos−mθ1,sin−mθ2,...,cos−mθd−1,sin−mθd]
三角函数性质:
c
o
s
(
−
θ
)
=
c
o
s
θ
cos(-\theta)=cos\theta
cos(−θ)=cosθ
s
i
n
(
−
θ
)
=
−
s
i
n
θ
sin(-\theta)=-sin\theta
sin(−θ)=−sinθ
则有:
e
i
θ
(
−
m
)
=
[
c
o
s
m
θ
1
,
−
s
i
n
m
θ
2
,
.
.
.
,
c
o
s
m
θ
d
−
1
,
−
s
i
n
m
θ
d
]
e^{i\theta(-m)}=[cosm\theta_1,-sinm\theta_2,...,cosm\theta_{d-1},-sinm\theta_d]
eiθ(−m)=[cosmθ1,−sinmθ2,...,cosmθd−1,−sinmθd]
转为复数形式表示就是:
e
i
θ
(
−
m
)
=
[
c
o
s
m
θ
1
,
−
i
s
i
n
m
θ
2
,
.
.
.
,
c
o
s
m
θ
d
−
1
,
−
i
s
i
n
m
θ
d
]
e^{i\theta(-m)}=[cosm\theta_1,-i\ sinm\theta_2,...,cosm\theta_{d-1},-i \ sinm\theta_d]
eiθ(−m)=[cosmθ1,−i sinmθ2,...,cosmθd−1,−i sinmθd]
刚好就对应
e
i
θ
m
e^{i\theta m}
eiθm 的共轭
e
i
θ
m
=
[
c
o
s
m
θ
1
,
+
i
s
i
n
m
θ
2
,
.
.
.
,
c
o
s
m
θ
d
−
1
,
+
i
s
i
n
m
θ
d
]
e^{i\theta m}=[cosm\theta_1,+i\ sinm\theta_2,...,cosm\theta_{d-1},+i \ sinm\theta_d]
eiθm=[cosmθ1,+i sinmθ2,...,cosmθd−1,+i sinmθd]
所以可得:
o
n
=
∑
m
=
1
n
γ
n
−
m
(
Q
n
e
i
θ
n
)
(
K
m
e
i
θ
m
)
T
v
m
o_n = \sum_{m=1}^{n} \gamma ^{n-m}(Q_{n}e^{i\theta n})(K_{m}e^{i\theta m})^{T} v_{m}
on=m=1∑nγn−m(Qneiθn)(Kmeiθm)Tvm
=
∑
m
=
1
n
γ
n
−
m
(
Q
n
e
i
θ
n
)
(
K
m
e
i
θ
m
)
†
v
m
= \sum_{m=1}^{n} \gamma ^{n-m}(Q_{n}e^{i\theta n})(K_{m}e^{i\theta m})^{\dagger} v_{m}
=m=1∑nγn−m(Qneiθn)(Kmeiθm)†vm
其中
†
\dagger
† 表示共轭转置操作。
Retention 的训练并行表示
首先回顾单个时间步
n
n
n 的输出
o
n
o_n
on 的计算公式如下:
o
n
=
∑
m
=
1
n
γ
n
−
m
(
Q
n
e
i
θ
n
)
(
K
m
e
i
θ
m
)
†
v
m
o_n= \sum_{m=1}^{n} \gamma ^{n-m}(Q_{n}e^{i\theta n})(K_{m}e^{i\theta m})^{\dagger} v_{m}
on=m=1∑nγn−m(Qneiθn)(Kmeiθm)†vm
而所有时间步的输出是可以并行计算的,用矩阵形式表达如下:
(
(
Q
⊙
Θ
)
(
K
⊙
Θ
ˉ
)
T
⊙
D
)
V
((Q\odot\Theta)(K\odot\bar{\Theta})^T\odot D)V
((Q⊙Θ)(K⊙Θˉ)T⊙D)V
其中
V
∈
R
∣
x
∣
×
d
V \in \mathbb{R}^{|x| \times d}
V∈R∣x∣×d,而
⊙
\odot
⊙ 表示两个矩阵逐元素相乘,
Q
∈
R
∣
x
∣
×
d
Q \in \mathbb{R}^{|x| \times d}
Q∈R∣x∣×d和
K
∈
R
∣
x
∣
×
d
K \in \mathbb{R}^{|x| \times d}
K∈R∣x∣×d 每一行对应一个时间步的
q
q
q 和
k
k
k 向量。
而
Θ
∈
R
∣
x
∣
×
d
\Theta \in \mathbb{R}^{|x| \times d}
Θ∈R∣x∣×d 每一行对应向量
e
i
θ
n
,
n
=
1
,
.
.
.
,
∣
x
∣
e^{i\theta n},n=1,...,|x|
eiθn,n=1,...,∣x∣ 。
Θ
ˉ
∈
R
∣
x
∣
×
d
\bar{\Theta} \in \mathbb{R}^{|x| \times d}
Θˉ∈R∣x∣×d就是对应
Θ
\Theta
Θ 矩阵的共轭,也就是将
Θ
\Theta
Θ 矩阵每一行改为复数的共轭形式。
而
D
∈
R
∣
x
∣
×
∣
x
∣
D \in \mathbb{R}^{|x| \times |x|}
D∈R∣x∣×∣x∣ 矩阵是一个下三角矩阵,其中第
n
n
n 行第
m
m
m 列的元素计算方式:
D
n
m
=
γ
n
−
m
,
n
>
=
m
D_{nm}=\gamma^{n-m},n>=m
Dnm=γn−m,n>=m
D
n
m
=
0
,
n
<
m
D_{nm}=0,n<m
Dnm=0,n<m
标准Transformers是:
s
o
f
t
m
a
x
(
Q
K
T
/
d
k
)
V
softmax(QK^T/\sqrt{d_k})V
softmax(QKT/dk)V
而RetNet是:
(
Q
K
T
⊙
D
)
V
(QK^T\odot D)V
(QKT⊙D)V
除去和尺度、稀疏化等相关的softmax之外,只差了一个
D
D
D,可以看到,这个
D
D
D就是上面公式,合并记忆节点的线性加权权重系数。
相当于一个逐渐衰减的不可训练的相对位置编码。
Retention 的推理循环表示
推理阶段的循环表示论文中定义如下:
S
n
=
γ
S
n
−
1
+
K
n
T
V
n
S_n=\gamma S_{n-1}+{K_n}^TV_n
Sn=γSn−1+KnTVn
R
e
t
e
n
t
i
o
n
(
X
n
)
=
Q
n
S
n
Retention(X_n)=Q_nS_n
Retention(Xn)=QnSn
单个时间步
n
n
n 的输出
o
n
o_n
on 的计算公式:
o
n
=
∑
m
=
1
n
γ
n
−
m
(
Q
n
e
i
θ
n
)
(
K
m
e
i
θ
m
)
†
v
m
o_n= \sum_{m=1}^{n} \gamma ^{n-m}(Q_{n}e^{i\theta n})(K_{m}e^{i\theta m})^{\dagger} v_{m}
on=m=1∑nγn−m(Qneiθn)(Kmeiθm)†vm
=
Q
n
e
i
θ
n
(
∑
m
=
1
n
γ
n
−
m
(
K
m
e
i
θ
m
)
†
v
m
)
= Q_{n}e^{i\theta n}(\sum_{m=1}^{n}\gamma ^{n-m}(K_{m}e^{i\theta m})^{\dagger} v_{m})
=Qneiθn(m=1∑nγn−m(Kmeiθm)†vm)
=
Q
n
e
i
θ
n
(
γ
n
−
n
(
K
n
e
i
θ
n
)
†
v
n
+
(
∑
m
=
1
n
−
1
γ
n
−
m
(
K
m
e
i
θ
m
)
†
v
m
)
= Q_{n}e^{i\theta n}(\gamma ^{n-n}(K_{n}e^{i\theta n})^{\dagger} v_{n}+(\sum_{m=1}^{n-1}\gamma ^{n-m}(K_{m}e^{i\theta m})^{\dagger} v_{m})
=Qneiθn(γn−n(Kneiθn)†vn+(m=1∑n−1γn−m(Kmeiθm)†vm)
=
Q
n
e
i
θ
n
(
(
K
n
e
i
θ
n
)
†
v
n
+
(
∑
m
=
1
n
−
1
γ
n
−
m
(
K
m
e
i
θ
m
)
†
v
m
)
= Q_{n}e^{i\theta n}((K_{n}e^{i\theta n})^{\dagger} v_{n}+(\sum_{m=1}^{n-1}\gamma ^{n-m}(K_{m}e^{i\theta m})^{\dagger} v_{m})
=Qneiθn((Kneiθn)†vn+(m=1∑n−1γn−m(Kmeiθm)†vm)
=
Q
n
e
i
θ
n
(
(
K
n
e
i
θ
n
)
†
v
n
+
γ
(
K
n
−
1
e
i
θ
(
n
−
1
)
)
†
v
n
+
∑
m
=
1
n
−
2
γ
n
−
m
−
1
(
K
m
e
i
θ
m
)
†
v
m
)
= Q_{n}e^{i\theta n}((K_{n}e^{i\theta n})^{\dagger} v_{n}+\gamma(K_{n-1}e^{i\theta (n-1)})^{\dagger} v_{n}+\sum_{m=1}^{n-2}\gamma ^{n-m-1}(K_{m}e^{i\theta m})^{\dagger} v_{m})
=Qneiθn((Kneiθn)†vn+γ(Kn−1eiθ(n−1))†vn+m=1∑n−2γn−m−1(Kmeiθm)†vm)
上述公式最后一步和推理阶段循环表示公式中各个元素的对应关系是:
Q
n
=
Q
n
e
i
θ
n
Q_n=Q_ne^{i\theta n}
Qn=Qneiθn
S
n
−
1
=
(
K
n
−
1
e
i
θ
(
n
−
1
)
)
†
v
n
+
∑
m
=
1
n
−
2
γ
n
−
m
−
1
(
K
m
e
i
θ
m
)
†
v
m
S_{n-1}=(K_{n-1}e^{i\theta (n-1)})^{\dagger} v_{n}+\sum_{m=1}^{n-2}\gamma ^{n-m-1}(K_{m}e^{i\theta m})^{\dagger} v_{m}
Sn−1=(Kn−1eiθ(n−1))†vn+m=1∑n−2γn−m−1(Kmeiθm)†vm
K
n
T
V
n
=
(
K
n
e
i
θ
n
)
†
v
n
{K_n}^TV_n=(K_{n}e^{i\theta n})^{\dagger} v_{n}
KnTVn=(Kneiθn)†vn
对应论文中的图示:
图中的
G
N
GN
GN 表示 GroupNorm。
普通Transformer在解码过程中都有Key和Value,所以解码需要
O
(
n
)
O(n)
O(n)复杂度。
而RetNet解码只需要
O
(
1
)
O(1)
O(1)复杂度,这是因为其将前序节点的表示做成了一个向量,类似于RNN一样循环更新,用Key和Value计算出一个向量之后,去更新权重向量
S
S
S,相当于RNN中的记忆向量。而输出就是Query与
S
S
S的计算结果。
标准RNN是当前输入
X
t
X_t
Xt和记忆向量
S
t
−
1
S_{t-1}
St−1接起来过全连接层,即
S
t
=
f
(
U
⋅
X
t
+
W
⋅
S
t
−
1
)
S_t=f(U·X_t+W·S_{t-1})
St=f(U⋅Xt+W⋅St−1),RetNet是用一个实数
γ
\gamma
γ做权重,进行线性加权。
可以看到在推理阶段,RetNet 在计算当前时间步
n
n
n 的输出
O
n
O_n
On 只依赖于上一个时间步产出的状态矩阵
S
n
−
1
S_{n-1}
Sn−1。
其实就是把计算顺序改了一下,先计算的
K
n
K_n
Kn 和
V
n
V_n
Vn 的相乘然后一直累加到状态矩阵
S
n
S_n
Sn 上,最后再和
Q
n
Q_n
Qn 相乘。
而不是像 Transformer 架构那样,每个时间步的计算要先算
Q
n
Q_n
Qn 和前面所有时间步的
K
K
K 相乘得到 attention 权值再和
V
V
V 相乘求和,这样就需要一直保留历史的
K
K
K 和
V
V
V。
Gated Multi-Scale Retention
然后 RetNet 每一层中的 Retention 子模块其实也是分了
h
h
h 个头,每个头用不同的
W
Q
,
W
K
,
W
V
∈
R
d
×
d
W_{Q}, W_{K}, W_{V} \in \mathbb{R}^{d \times d}
WQ,WK,WV∈Rd×d 参数,同时每个头都采用不同的
γ
\gamma
γ 常量,这也是 Multi-Scale Retention 名称的来由。
则对输入
X
X
X, MSR 层的输出是:
γ
=
1
−
2
−
5
−
a
r
a
n
g
e
(
0
,
h
)
∈
R
h
\gamma = 1-2^{-5-arange(0,h)} \in \mathbb{R}^{h }
γ=1−2−5−arange(0,h)∈Rh
h
e
a
d
i
=
R
e
t
e
n
t
i
o
n
(
X
,
γ
i
)
head_i= Retention(X,\gamma_i)
headi=Retention(X,γi)
Y
=
G
r
o
u
p
N
o
r
m
h
(
C
o
n
c
a
t
(
h
e
a
d
1
,
.
.
.
,
h
e
a
d
h
)
)
Y=GroupNorm_h(Concat(head_1,...,head_h))
Y=GroupNormh(Concat(head1,...,headh))
M
S
R
(
X
)
=
(
s
w
i
s
h
(
X
W
G
)
⊙
Y
)
W
O
MSR(X)=(swish(XW_G)\odot\ Y)W_O
MSR(X)=(swish(XWG)⊙ Y)WO
其中
W
G
,
W
O
∈
R
d
∗
h
×
d
∗
h
W_{G}, W_{O} \in \mathbb{R}^{d*h \times d*h}
WG,WO∈Rd∗h×d∗h ,
s
w
i
s
h
swish
swish 是激活函数用来生成门控阈值,还有由于每个头均采用不同的
γ
\gamma
γ,所以每个头的输出要单独做 normalize 之后再 concat。