Attention总结

前提

因为在对大模型推理过程中,会对原始的注意力做各种修改以加快推理速度—即节约时间,也节约显存。在下文中我会对自己遇到的Attention做以总结。

Self-Attention

《Attention is All You Need》这篇文章中使用的Attention公式如下:
A t t e n t i o n ( Q , K , V ) = S o f t m a x ( Q K T d k ) T Attention(Q,K,V)=Softmax(\frac{QK^T}{\sqrt{d_k}})T Attention(Q,K,V)=Softmax(dk QKT)T
如果Q、K、 V是同一个x经过线性变化而来,则我们称之为Self-Attention,这样可以捕捉x内任意两个位置处的关系。

这里我们更关心的是Self-Attention的时间复杂度O(n^2),那么为什么是O(n^2)
答:记输入序列x的长度为n,经过embedding后一个位置的维度为d。例如,(32,512,768)的序列,其中32为batch_size,n=512为序列长度,d=769为embedding后的维度。

首先Q、K、 Vx经过线性变化而来,因此Q、K、 Vx同维度大小,都是n*d

  • 相似度计算 Q K T QK^T QKT: n × d n\times d n×d d × n d \times n d×n做矩阵乘法,得到 n × n n\times n n×n矩阵,时间复杂度为 O ( n 2 d ) O(n^2d) O(n2d)。这里可以给出一个结论:矩阵A大小为 n × m n\times m n×m,矩阵B大小为 m × n m\times n m×n,则矩阵 A × B A\times B A×B的时间复杂度为 O ( n × m × n ) O(n\times m \times n) O(n×m×n)。可以用以下代码进行验证:
int main() {
	// n=3, m=2
	int A[3][2] = {{1,2},{3,4},{5,6}};
	int B[2][3] = {{1,2,3},{4,5,6}};
	int C[3][3] = {0};
	// 计算A与B相乘
	for (int row = 0; row < 3; row++) { 
		for (int col = 0; col < 3; col++) {
			for (int k = 0; k < 2; k++) {
				C[row][col] = C[row][col] + A[row][k] * B[k][col];
			}
		}
	}
	for (int i = 0; i < 3; i++) {
		for (int j = 0; j < 3; j++) {
			printf("%d ", C[i][j]);
		}
		printf("\n");
	}
	return 0;
}

可以看到三层for循环,那这样时间复杂度一下就明了了。

  • Q × K T d k \frac{Q\times K^T}{\sqrt{d_k}} dk Q×KT的复杂度: Q K T QK^T QKT的大小为 n × n n\times n n×n,因此时间复杂度 O ( n 2 ) O(n^2) O(n2)

  • softmax计算:对大小为 n n n的一维向量做softmax,时间复杂度为 O ( n ) O(n) O(n) Q K T QK^T QKT的大小为 n × n n\times n n×n,共 n n n行;因此softmax的时间复杂度 O ( n 2 ) O(n^2) O(n2)

  • 乘以 V V V的权重: n × n n\times n n×n n × d n\times d n×d运算,得到 n × d n\times d n×d矩阵,复杂度 O ( n 2 d ) O(n^2d) O(n2d)

综上,Self-Attention的时间复杂度为: O ( n 2 d + n 2 + n 2 + n 2 d ) ≈ O ( n 2 d ) O(n^2d+n^2+n^2+n^2d)\approx O(n^2d) O(n2d+n2+n2+n2d)O(n2d)。如果将 d d d视作为一个常数,则就说自注意力的时间复杂度与序列长度成平方关系。

再来看看空间复杂度: Q K T QK^T QKT Q × K T d k \frac{Q\times K^T}{\sqrt{d_k}} dk Q×KT S o f t m a x ( Q × K T d k ) Softmax(\frac{Q\times K^T}{\sqrt{d_k}}) Softmax(dk Q×KT)大小都是 n × n n\times n n×n,所以空间复杂度为 O ( n 2 ) O(n^2) O(n2); S o f t m a x ( Q × K T d k ) V Softmax(\frac{Q\times K^T}{\sqrt{d_k}})V Softmax(dk Q×KT)V大小为 n × d n\times d n×d,所以空间复杂度为 O ( n d ) O(nd) O(nd)。所以说整个Self-Attention的空间复杂度为 O ( n 2 + n d ) O(n^2+nd) O(n2+nd)

Sparse Self-Attention

Sparse Self-Attention(稀疏注意力)来自2019年的这篇论文《Generating Long Sequences with Sparse Transformers》。下面对这篇文章稍微做一下解读。

1. Generating Long Sequences with Sparse Transformers

具体可以参考这篇文章。文章就是想办法降低自注意力的复杂度。通过一系列的实验发现,不需要对每个位置都和其他位置计算注意力,因为注意力权重高的地方只占一小部分,因此也就不需要那么密集的注意力。

2. Atrous self attention(空洞自注意力)

原始的自注意力如下图所示:

在这里插入图片描述
那我们现在跳着做自注意力,比如说:对上图的第5个位置,我们将其和第2,第5,第8…个位置计算注意力,这样是不是就可以把复杂度降下来。也就是说每个位置只会和 n / k n/k n/k个位置计算注意力, k k k是间隔。如下图所示:
在这里插入图片描述
这样的话时间复杂度就会降为 O ( n 2 / k ) O(n^2/k) O(n2/k),也就降为原来的 1 k \frac{1}{k} k1

3. local self-attention(局部自注意力)

还有一种方法是只计算当前位置与前后几个位置的注意力。如下图所示:
在这里插入图片描述
这样就是强制约束该位置只和自身以及前后 k k k个位置有关联,总共有 2 k + 1 2k+1 2k+1个位置,距离超过 k k k的注意力直接设置为0。
那么时间复杂度呢?由于每个位置只与 2 k + 1 2k+1 2k+1个位置有关联,所以时间复杂度为 O ( ( 2 k + 1 ) n ) ≈ O ( k n ) O((2k+1)n)\approx O(kn) O((2k+1)n)O(kn),可以看到现在的时间复杂度与序列长度成线性关系,显著降低了时间复杂度。

4. Sparse Attention(稀疏注意力)

Sparse Attention直接将空洞注意力与局部注意力结合起来,如下图所示:
在这里插入图片描述
除过箭头连接的位置,其他位置的注意力都置为0。这样一来Attention就具有局部紧密相关和远程稀疏相关的特性,paper中说这样的复杂度为 O ( n n ) O(n\sqrt{n}) O(nn )。这里先留一个坑,为什么这样一来复杂度就是 O ( n n ) O(n\sqrt{n}) O(nn )呢?日后来填以及介绍更多的注意力。

参考链接

  1. https://zhuanlan.zhihu.com/p/473389061
  2. https://zhuanlan.zhihu.com/p/661804092
  3. https://zhuanlan.zhihu.com/p/260928791
  4. https://zhuanlan.zhihu.com/p/473389061
  5. https://blog.csdn.net/qq_39463175/article/details/111818717
  6. 苏剑林. (Jul. 27, 2019). 《为节约而生:从标准Attention到稀疏Attention 》[Blog post]. Retrieved from https://spaces.ac.cn/archives/6853
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值