本文主要是阅读论文《Fast Transformer Decoding: One Write-Head is All
You Need》的学习记录,这是一篇2019年的改善Multi-Head Attention带来的显存占用瓶颈问题的一种解决方案。
本文根据论文的撰写顺序,从Multi-Head Attention的代码开始,分析MHA在训练阶段和推理阶段的计算量和显存占用参数量;然后给出Multi-Query Attention的代码,分析MQA在训练阶段和推理阶段的计算量和显存占用情况。
提前预备:
- 假设读者了解decoder解码器的结构,包括L层Transformer,每层Transformer由一个self-attention层和两个MLP层构成,self-attention层默认为采用H头的注意力层。
- 了解在训练过程中,是通过mask来实现并行训练的。如果对这部分不是很熟悉,可以去看一下相关的知乎或者博客。
- 了解在推理过程中,模型的输入是当前时刻m的word embedding输入,输出是当前时刻的输出。在中间计算过程中需要用到过去所有时刻的输入来计算Key和Value,实现注意力机制,因此max_length越长的模型,推理的时间开销越大,因为模型需要一次次的重新计算key和value张量。因此在实现中往往采用KVcache的方式来进行加速。
关于KVcache的概念和计算可以参考知乎的文章《分析transformer模型的参数量、计算量、中间激活、KV cache》
在这里为了加快模型的推理速度,网络的输入为当前时刻
i
n
d
e
x
=
i
index=i
index=i的word embedding,模型会在self-attention层保存前面
(
i
−
1
)
(i-1)
(i−1)个
K
,
V
K,V
K,V的值
p
r
e
v
i
o
u
s
_
K
,
p
r
e
v
i
o
u
s
_
V
previous\_K,previous\_V
previous_K,previous_V,然后在得到当前
i
n
d
e
x
=
i
index=i
index=i的
n
e
w
_
k
,
n
e
w
_
v
new\_k,new\_v
new_k,new_v之后,
更新
p
r
e
v
i
o
u
s
_
K
=
c
o
n
c
a
t
(
p
r
e
v
i
o
u
s
_
k
,
n
e
w
_
k
)
previous\_K=concat(previous\_k,new\_k)
previous_K=concat(previous_k,new_k),
更新
p
r
e
v
i
o
u
s
_
V
=
c
o
n
c
a
t
(
p
r
e
v
i
o
u
s
_
v
,
n
e
w
_
v
)
previous\_V=concat(previous\_v,new\_v)
previous_V=concat(previous_v,new_v);
这样做的好处是加快当前token的推理速度,缺点是增加了模型的显存占用。且随着max_length的增加,这部分的
K
V
c
a
c
h
e
KV_{cache}
KVcache会成为推理的显存占用瓶颈。
一、 Attention_function
这部分主要介绍的是self-attention中的q,K,V是怎么实现注意力机制的。
需要注意的是这里的
q
q
q为
m
m
m时刻的hidden表示,而
K
,
V
K,V
K,V为序列中所有时刻的hidden表示,因此维度为
m
m
m.
输入:单个时刻的输入
q
q
q,维度为
(
1
,
k
)
(1,k)
(1,k),前面m个时刻的
K
K
K,维度为
(
m
,
k
)
(m,k)
(m,k),前面
m
m
m时刻的
V
V
V,维度为
(
m
,
v
)
(m,v)
(m,v)
输出:根据自注意力机制得到的输出
y
y
y
首先计算
q
q
q和
m
m
m个key的相似度,得到一个
m
m
m维的权重,然后将权重进行sigmoid标准化。这样得到的权重的每一个
W
i
W_i
Wi都表示q和第
i
i
i个Key的相似度或者说相关度,其值越大表示越相关。
将这个权重和对应的
v
a
l
u
e
value
value相乘,也就是赋予对应位置的
V
i
V_i
Vi更高的注意力权重。最后的返回值就是这个将带权重的Value叠加作用到最终结果的一个过程。
上面这个过程就是一个简单的attention过程。
二、Multi-head Attention
从上述简单的attention过程,引入Multihead的概念。
- 输入 x x x:张量维度为 d d d
- M M M为所有 m m m时刻的输入,维度为 [ m , d ] [m,d] [m,d]
- P _ q P\_q P_q实现的功能为:1. 提供一个线性映射转换,2. 转换后的 x x x reshape成 h h h个head,用于表征Multi-head这个概念。
- P _ k P\_k P_k, P _ v P\_v P_v的功能为:1. 为 m m m个时刻的输入向量 M M M提供一个线性映射变换,2. 将转换之后的 k e y key key, v a l u e value value reshape成Multi-head
- h h h表示head的数量, k k k为每个head中张量的维度,一般的实现中, h × k = d h \times k=d h×k=d,而且 k = v k=v k=v,也就是key和value的最后一个维度是一致的。
在论文的表示中,大写的
K
V
KV
KV表示张量名称,小写的
k
、
v
k、v
k、v表示长度。需要注意不要搞混了。
可以看到Multi-head Attention和上面的最简单形式的attention相比:
- 多了 q K V qKV qKV的计算方式,不再是直接存在 q , K , V q,K,V q,K,V,而是给出了 K V KV KV是通过 m m m个时刻的输入和投影矩阵 P _ k , P _ v P\_k,P\_v P_k,P_v计算得到, q q q是通过输入 x x x和 P _ q P\_q P_q计算得到的。
- 多了head的概念,将当前时刻的输入 x x x和 k e y key key, v a l u e value value的计算结果都reshape成h个集合。
- 输出也不是简单的通过value的权重和得到,而是需要增加一个输出映射 P o P_o Po
三、训练过程中的Multi-Head-attention
训练过程和上面这个过程的区别在于
- 训练过程多了batch这个维度,可以同时计算不同序列的输出。提高了显存利用率。
- 训练过程有GT,知道每个时刻的输入和输出,因此可以通过mask方式来实现并行计算,可以计算同一个序列上不同位置的输出。而不需要向推理那样等待前一个时刻的输出,来作为后一时刻的输入。
符号说明
- m m m表示的整个序列的输入,通过mask来实现遮蔽指定位置之后的其他输入。
- 这里的 n n n表示 n n n个时刻的输入(因为原因2,所以可以并行前向推理),也就是可以通过一次前向过程计算 n n n个时刻的输出。一般实现的时候, n = m n=m n=m。这里的 n n n一般为序列的长度,也就是会在依次前向过程中预测整个序列的每一个元素。通过 m a s k [ n , m ] mask[n,m] mask[n,m]来实现mask掉后面的输入。具体来说,假设 m = n = 10 m=n=10 m=n=10,那么对于时刻 i = 4 i=4 i=4来说,他的 m a s k = [ 1 , 1 , 1 , 1 , − i n f , − i n f , − i n f , . . . ] mask=[1,1,1,1,-inf,-inf,-inf,...] mask=[1,1,1,1,−inf,−inf,−inf,...],这样在计算KV计算只使用该时刻之前的输入,而不使用后面时刻的输入。
其余的和上面的基本没有什么变化
分析一下计算量,建立在
- b = b a t c h s i z e b=batchsize b=batchsize,
- n = m = n=m= n=m=序列的长度,
- h ∗ k = d h*k=d h∗k=d,
- k = v k=v k=v
这四条实践经验的基础上,我们知道矩阵乘法的计算量为最终输出的元素个数乘以每个元素的计算量。 我们以矩阵一次乘-加运算作为计算基础。
- Q K V QKV QKV的计算量都是 b n d 2 bnd^2 bnd2,那么总计算量为 3 b n d 2 3bnd^2 3bnd2
- logits的计算量为 b h n 2 k = b n 2 d bhn^2k=bn^2d bhn2k=bn2d
- weights的计算量为 (忽略一下)
- O计算量为 b h n v ∗ n = b h n 2 v = b n 2 d bhnv*n=bhn^2v=bn^2d bhnv∗n=bhn2v=bn2d
- Y的计算量 b n d ∗ d = b n d 2 bnd*d=bnd^2 bnd∗d=bnd2
综合来看,计算复杂度为 θ ( b n d 2 ) \theta (bnd^2) θ(bnd2),考虑到 n ⩽ d n \leqslant d n⩽d的情况:
计算参数量
- 1. X , M , Q , K , V , O , Y X,M,Q,K,V,O,Y X,M,Q,K,V,O,Y的参数都是bnd,那么总参数量为 6 b n d 6bnd 6bnd,
- 2.logits和weights总参数量为 2 b h n 2 2bhn^2 2bhn2,
- 3.参数 P _ v , P _ q , P _ k , P _ o P\_v,P\_q,P\_k,P\_o P_v,P_q,P_k,P_o的参数总量为 4 d 2 4d^2 4d2.
考虑到1的部分元素是中间激活值,不会保存,2的全部元素是中间激活值,不会保存。那么空间复杂度可以表示为O(bnd+d^2)
四、推理过程的MHA
这里着重讨论的是生成式模型,因为生成式模型的在某个时刻 i i i的输出依赖于从0时刻到当前时刻的所有输入。且在当前时刻的输出得到之前,无法计算后面时刻的输出,因此不像训练阶段那样可以并行执行前向过程。
但是前面时刻的输入主要是为了计算
K
V
KV
KV,
q
q
q只和当前时刻的输入有关。因此想要加速推理,可以将
K
V
KV
KV在推理过程中进行缓存保留。这也就是
K
V
c
a
c
h
e
KVcache
KVcache,这种方式存在的问题是
K
V
c
a
c
h
e
KVcache
KVcache是需要占用显存的。
这个公式和上面训练阶段的公式的区别在于
K
V
KV
KV的计算.
K V KV KV是通过之前时刻的 p r e v _ K prev\_K prev_K和当前时刻根据当前输入 M M M(我觉得这里的 M M M和 X X X表达的是一个意思,都是当前时刻的输入,为了和前面一章节的公式保持一致所以才这么写的。但是不确定我的这种理解对不对,欢迎讨论或者告知这里的 M M M是什么)计算得到的 n e w _ K new\_K new_K, n e w _ V new\_V new_V.
类似上一章节的计算复杂度推理,这里的考虑到不能并行计算的情况,那么对于长度为 n n n的序列,需要计算 n n n次,那么计算复杂度为 θ ( b n d 2 ) \theta (bnd^2) θ(bnd2);
空间复杂度也是同样的,在一次计算中,不考虑中间激活值,主要保存是是 K V c a c h e KVcache KVcache,以及矩阵 P _ q , P _ k , P _ v , P _ o P\_q,P\_k,P\_v,P\_o P_q,P_k,P_v,P_o.前者的参数量为 O ( b n 2 d ) O(bn^2d) O(bn2d),后者的参数量为 O ( d 2 ) O(d^2) O(d2)。可以看到随着序列长度 n n n的增加或者隐藏层深度 d d d的增加,这部分的显存占用将会称为瓶颈。
五、Multi-Query Attention
介绍完MHA之后,就可以介绍MQA了。
左边是MHA,右边是MQA,中间是没有提到过的GQA(Group-Query Attention,可以理解为不那么激进的MQA)。
可以看到MHA中每个head独立维护自己的 K V c a c h e KVcache KVcache。在MQA中则是同一层Transformer中的所有head维护一个 K V c a c h e KVcache KVcache,那么 K V c a c h e KVcache KVcache的显存可以明显的减少 h h h倍。
另外KV的投影映射矩阵
P
_
k
,
p
_
v
P\_k,p\_v
P_k,p_v也会同样的减少
h
h
h倍。
这里好像就没有什么好讲的啦~~和MHA的第三章节唯一的不同就是算法前面提到的两点。其他都是一样一样的
六、 推理阶段的MQA
x
x
x是某个时刻的输入,维度为
[
b
,
d
]
[b,d]
[b,d]
K
V
c
a
c
h
e
KVcache
KVcache的更新不再独立于每个head,而是所有head共享,因此
P
k
,
P
v
,
p
r
e
v
_
K
,
p
r
e
v
_
V
P_k,P_v,prev\_K,prev\_V
Pk,Pv,prev_K,prev_V的维度都少了一维。
七、实验对比
作者在英文转德文的数据集上进行了实验,采用的模型是具有2.1亿参数、6层layers,d_model=1024,head=8,且位置编码可学习的模型。
在对比MHA和MQA的时候,为了保持总参数量不变,会将MQA中的MLP层的hidden layer从1024*4增大为5440.
训练的效果评估指标如下所示,可以看到相比于MHA有轻微的性能衰减。PPL指标越小越好,BLEU是用于衡量翻译质量的数据集,越大越好。
从训练时间和推理时间上看,训练时有轻微的减少。对于单个training step从MHA的433ms减少到MQA的425ms。这里的训练上并没有很明显的提升。
在推理上的提升是很明显的。实验是在一个英文转德文的数据集上进行的,encoder的长度为128,decoder的长度也是128。encoder模块推理需要222ms,在MHA中,decoder需要47ms。在MQA上,encoder的时间为195ms,decoder只需要3.9ms
下表中时间是每个token上的时间。所以和上面我们提到的时间看起来不同,实际是一样的。