前提
因为在对大模型推理过程中,会对原始的注意力做各种修改以加快推理速度—即节约时间,也节约显存。在下文中我会对自己遇到的Attention做以总结。
Self-Attention
《Attention is All You Need》这篇文章中使用的Attention公式如下:
A
t
t
e
n
t
i
o
n
(
Q
,
K
,
V
)
=
S
o
f
t
m
a
x
(
Q
K
T
d
k
)
T
Attention(Q,K,V)=Softmax(\frac{QK^T}{\sqrt{d_k}})T
Attention(Q,K,V)=Softmax(dkQKT)T
如果Q、K、 V
是同一个x
经过线性变化而来,则我们称之为Self-Attention,这样可以捕捉x
内任意两个位置处的关系。
这里我们更关心的是Self-Attention的时间复杂度O(n^2)
,那么为什么是O(n^2)
?
答:记输入序列x
的长度为n
,经过embedding后一个位置的维度为d。例如,(32,512,768)的序列,其中32为batch_size,n=512为序列长度,d=769为embedding后的维度。
首先Q、K、 V
是x
经过线性变化而来,因此Q、K、 V
与x
同维度大小,都是n*d
。
- 相似度计算 Q K T QK^T QKT: n × d n\times d n×d与 d × n d \times n d×n做矩阵乘法,得到 n × n n\times n n×n矩阵,时间复杂度为 O ( n 2 d ) O(n^2d) O(n2d)。这里可以给出一个结论:矩阵A大小为 n × m n\times m n×m,矩阵B大小为 m × n m\times n m×n,则矩阵 A × B A\times B A×B的时间复杂度为 O ( n × m × n ) O(n\times m \times n) O(n×m×n)。可以用以下代码进行验证:
int main() {
// n=3, m=2
int A[3][2] = {{1,2},{3,4},{5,6}};
int B[2][3] = {{1,2,3},{4,5,6}};
int C[3][3] = {0};
// 计算A与B相乘
for (int row = 0; row < 3; row++) {
for (int col = 0; col < 3; col++) {
for (int k = 0; k < 2; k++) {
C[row][col] = C[row][col] + A[row][k] * B[k][col];
}
}
}
for (int i = 0; i < 3; i++) {
for (int j = 0; j < 3; j++) {
printf("%d ", C[i][j]);
}
printf("\n");
}
return 0;
}
可以看到三层for
循环,那这样时间复杂度一下就明了了。
-
Q × K T d k \frac{Q\times K^T}{\sqrt{d_k}} dkQ×KT的复杂度: Q K T QK^T QKT的大小为 n × n n\times n n×n,因此时间复杂度为 O ( n 2 ) O(n^2) O(n2)。
-
softmax计算:对大小为 n n n的一维向量做softmax,时间复杂度为 O ( n ) O(n) O(n), Q K T QK^T QKT的大小为 n × n n\times n n×n,共 n n n行;因此softmax的时间复杂度为 O ( n 2 ) O(n^2) O(n2)。
-
乘以 V V V的权重: n × n n\times n n×n 与 n × d n\times d n×d运算,得到 n × d n\times d n×d矩阵,复杂度为 O ( n 2 d ) O(n^2d) O(n2d)
综上,Self-Attention的时间复杂度为: O ( n 2 d + n 2 + n 2 + n 2 d ) ≈ O ( n 2 d ) O(n^2d+n^2+n^2+n^2d)\approx O(n^2d) O(n2d+n2+n2+n2d)≈O(n2d)。如果将 d d d视作为一个常数,则就说自注意力的时间复杂度与序列长度成平方关系。
再来看看空间复杂度: Q K T QK^T QKT、 Q × K T d k \frac{Q\times K^T}{\sqrt{d_k}} dkQ×KT、 S o f t m a x ( Q × K T d k ) Softmax(\frac{Q\times K^T}{\sqrt{d_k}}) Softmax(dkQ×KT)大小都是 n × n n\times n n×n,所以空间复杂度为 O ( n 2 ) O(n^2) O(n2); S o f t m a x ( Q × K T d k ) V Softmax(\frac{Q\times K^T}{\sqrt{d_k}})V Softmax(dkQ×KT)V大小为 n × d n\times d n×d,所以空间复杂度为 O ( n d ) O(nd) O(nd)。所以说整个Self-Attention的空间复杂度为 O ( n 2 + n d ) O(n^2+nd) O(n2+nd)。
Sparse Self-Attention
Sparse Self-Attention(稀疏注意力)来自2019年的这篇论文《Generating Long Sequences with Sparse Transformers》。下面对这篇文章稍微做一下解读。
1. Generating Long Sequences with Sparse Transformers
具体可以参考这篇文章。文章就是想办法降低自注意力的复杂度。通过一系列的实验发现,不需要对每个位置都和其他位置计算注意力,因为注意力权重高的地方只占一小部分,因此也就不需要那么密集的注意力。
2. Atrous self attention(空洞自注意力)
原始的自注意力如下图所示:
那我们现在跳着做自注意力,比如说:对上图的第5个位置,我们将其和第2,第5,第8…个位置计算注意力,这样是不是就可以把复杂度降下来。也就是说每个位置只会和
n
/
k
n/k
n/k个位置计算注意力,
k
k
k是间隔。如下图所示:
这样的话时间复杂度就会降为
O
(
n
2
/
k
)
O(n^2/k)
O(n2/k),也就降为原来的
1
k
\frac{1}{k}
k1。
3. local self-attention(局部自注意力)
还有一种方法是只计算当前位置与前后几个位置的注意力。如下图所示:
这样就是强制约束该位置只和自身以及前后
k
k
k个位置有关联,总共有
2
k
+
1
2k+1
2k+1个位置,距离超过
k
k
k的注意力直接设置为0。
那么时间复杂度呢?由于每个位置只与
2
k
+
1
2k+1
2k+1个位置有关联,所以时间复杂度为
O
(
(
2
k
+
1
)
n
)
≈
O
(
k
n
)
O((2k+1)n)\approx O(kn)
O((2k+1)n)≈O(kn),可以看到现在的时间复杂度与序列长度成线性关系,显著降低了时间复杂度。
4. Sparse Attention(稀疏注意力)
Sparse Attention直接将空洞注意力与局部注意力结合起来,如下图所示:
除过箭头连接的位置,其他位置的注意力都置为0。这样一来Attention就具有局部紧密相关和远程稀疏相关的特性,paper中说这样的复杂度为
O
(
n
n
)
O(n\sqrt{n})
O(nn)。这里先留一个坑,为什么这样一来复杂度就是
O
(
n
n
)
O(n\sqrt{n})
O(nn)呢?日后来填以及介绍更多的注意力。
参考链接
- https://zhuanlan.zhihu.com/p/473389061
- https://zhuanlan.zhihu.com/p/661804092
- https://zhuanlan.zhihu.com/p/260928791
- https://zhuanlan.zhihu.com/p/473389061
- https://blog.csdn.net/qq_39463175/article/details/111818717
- 苏剑林. (Jul. 27, 2019). 《为节约而生:从标准Attention到稀疏Attention 》[Blog post]. Retrieved from https://spaces.ac.cn/archives/6853