文章目录
前言
回顾self-attention,对self-attention机制类别进行总结
How to make self-attention efficient?
什么是self-attention?
self-attention的存在是为了处理输入是sequence的状态。假设sequence长度为N,输入sequence后会产生N个 key vector和N个query vector,两两之间做dot product,得到NXN的Attention Matrix,做weight sum。
self-attention往往是一个巨大network其中一个module,比如transform。self-attention的运算量和N平方成正比,当N的长度非常长时,整个transform主要的运算来自于self-attention,此时加快self-attention才有帮助。
Sparse attention by human knowledge
Local Attention /Truncated Attention
有时候在做attention时不一定要看整个sequence,可能只要看左右的邻居的资讯,Attention Matrix如下图所示:
灰色代表把attention的weight设为0,只要计算蓝色部分的attention的weight,加速运算;这样的方法叫Local Attention。Local Attention的问题是只能看到小范围内的信息,和CNN没差别,它只能加快运算方法但结果不一定是好的。
Stride Attention
上面讲的是看邻居的咨询,Stride Attention看比较远的资讯,Attention Matrix如下图所示:
只计算绿色的部分,空的格数根据具体问题而设定。
Global Attention
在原来的sequence中加上特殊的token,会从sequence中的每个token收集资讯,Global Attention有两种做法:在原来的sequence中选一些token作为特殊的token或者外加额外的token,Attention Matrix如下图所示:
头两个位置是special token,其他的位置为0,不是special token,则彼此间不做attention。
上面是人为指定哪些位置要算self-attention,哪些位置直接估0。在Attention Matrix中有些位置attention的weight大,有些值小,把值特小的直接设为0,结果不会差很大。
怎么快速的估算哪些位置可能有大的attention值,进行详细的计算,小的值直接设0?
Clustering
步骤一:
把query和key拿出来,根据query和key的相近程度做Clustering,比较近的属于同一个Clustering,远的属于不同的Clustering,在例子中有4个Clustering。
步骤二:
query和key落在同样的Clustering中,才计算attention的weight,其他位置直接设为0。
Sinkhorn Sorting Network
Sinkhorn Sorting Network做的是哪些地方要不要做attention是用learned来决定的。
input sequence的每一个位置通过一个NN产生一个vector,vector的长度和sequence一样,vector拼起来的大小和Attention Matrix一样是nxn。
在linformer的文章中发现Attention Matrix中有很多redundant columns 冗余列,许多列重复,把重复的列拿掉,计算剩下的可以加快计算速度。
具体做法:
从N个key中挑选有代表性的K个出来,计算NXK的Attention Matrix。有N个value,挑出K个具有代表性的value vector,把K个key对第一个query算出的attention weight对这K个vector做gradient decent得到第一个位置的output。
怎么选择有代表性的key?
- 用CNN扫过sequence,长度变短,短的sequence作为代表性的key。
- input可以看作一个dxN的矩阵,乘上N X K 的矩阵,得到dXK的矩阵,每一列代表一个key。
attention可以看作一连串矩阵的相乘,计算过程可以简化。
input是I ,output是O 。
忽略softmax,可以写成
先算V X KT和先算KT X Q的结果是一样的,但计算量却不同。
先算KT X Q要做N X d X N 次的乘法,得到结果为N X N的矩阵A ,A 再乘 V 需要d’ X N X N的乘法次数;总次数是(d+d’)N2。先算V X KT要d’ x N x d次,变成d’ x d 的矩阵,再乘Q要d’ x d x N 次乘法,总次数是2d’dN次。