以 Attention Score 的计算为例
A
t
t
n
(
K
,
Q
,
V
)
=
S
o
f
t
m
a
x
(
Q
⋅
K
T
/
d
)
⋅
V
Attn(K,Q,V) = Softmax(Q\cdot K^T/\sqrt{d})\cdot V
Attn(K,Q,V)=Softmax(Q⋅KT/d)⋅V
咱姑且把 Softmax 和 Softmax里面的除以
d
\sqrt{d}
d 去掉(其运算时间复杂度小),表示为
A
t
t
n
(
K
,
Q
,
V
)
=
Q
⋅
K
T
⋅
V
Attn(K,Q,V) = Q\cdot K^T\cdot V
Attn(K,Q,V)=Q⋅KT⋅V
其中,
Q
,
K
,
V
∈
R
N
×
d
Q,K,V \in \mathbb{R}^{N\times d}
Q,K,V∈RN×d,
N
N
N 是token的数量,
d
d
d 是每个token的维度,一般认为
N
N
N>
d
d
d
Q ⋅ K T Q\cdot K^T Q⋅KT 从矩阵乘法上看维度变换是 N × d × d × N N\times d \times d \times N N×d×d×N,得到的矩阵维度是 N × N N\times N N×N,即得到的矩阵有 N 2 N^2 N2 个元素,每个元素需要经过d个元素相乘再相加得到(加权求和),所以 Q ⋅ K T Q\cdot K^T Q⋅KT 计算的时间复杂度为 O ( N 2 d ) O(N^2d) O(N2d)
- 总结一个快速得出结论的方法
如果你不想鸟我上面写的,你只需要按照这个规则来看
比如两个矩阵 M ⋅ N , M ∈ R m × n , M ∈ R n × k M\cdot N, M\in\mathbb{R}^{m\times n}, M\in\mathbb{R}^{n\times k} M⋅N,M∈Rm×n,M∈Rn×k
按照维度表示为 m × n × n × k m\times n \times n \times k m×n×n×k,只需要把中间的两个 n n n 删掉一个即可表示时间复杂度,为 O ( m × n × k ) O(m\times n\times k) O(m×n×k)