本文主要是阅读论文《Fast Transformer Decoding: One Write-Head is All
You Need》的学习记录,这是一篇2019年的改善Multi-Head Attention带来的显存占用瓶颈问题的一种解决方案。
本文根据论文的撰写顺序,从Multi-Head Attention的代码开始,分析MHA在训练阶段和推理阶段的计算量和显存占用参数量;然后给出Multi-Query Attention的代码,分析MQA在训练阶段和推理阶段的计算量和显存占用情况。
提前预备:
- 假设读者了解decoder解码器的结构,包括L层Transformer,每层Transformer由一个self-attention层和两个MLP层构成,self-attention层默认为采用H头的注意力层。
- 了解在训练过程中,是通过mask来实现并行训练的。如果对这部分不是很熟悉,可以去看一下相关的知乎或者博客。
- 了解在推理过程中,模型的输入是当前时刻m的word embedding输入,输出是当前时刻的输出。在中间计算过程中需要用到过去所有时刻的输入来计算Key和Value,实现注意力机制,因此max_length越长的模型,推理的时间开销越大,因为模型需要一次次的重新计算key和value张量。因此在实现中往往采用KVcache的方式来进行加速。
关于KVcache的概念和计算可以参考知乎的文章《分析transformer模型的参数量、计算量、中间激活、KV cache》
在这里为了加快模型的推理速度,网络的输入为当前时刻 i n d e x = i index=i index=i的word embedding,模型会在self-attention层保存前面 ( i − 1 ) (i-1) (i−1)个 K , V K,V K,V的值 p r e v i o u s _ K , p r e v i o u s _ V previous\_K,previous\_V previous_K,previous_V,然后在得到当前 i n d e x = i index=i index=i的 n e w _ k , n e w _ v new\_k,new\_v new_k,new_v之后,
更新 p r e v i o u s _ K = c o n c a t ( p r e v i o u s _ k , n e w _ k ) previous\_K=concat(previous\_k,new\_k) previous_K=concat(previous_k,new_k),
更新 p r e v i o u s _ V = c o n c a t ( p r e v i o u s _ v , n e w _ v ) previous\_V=concat(previous\_v,new\_v) previous_V=concat(previous_v,new_v);
这样做的好处是加快当前token的推理速度,缺点是增加了模型的显存占用。且随着max_length的增加,这部分的 K V c a c h e KV_{cache} KVcache会成为推理的显存占用瓶颈。
一、 Attention_function
这部分主要介绍的是self-attention中的q,K,V是怎么实现注意力机制的。
需要注意的是这里的 q q q为 m m m时刻的hidden表示,而 K , V K,V K,V为序列中所有时刻的hidden表示,因此维度为 m m m.
输入:单个时刻的输入 q q q,维度为 ( 1 , k ) (1,k) (1,k),前面m个时刻的 K K K,维度为 ( m , k ) (m,k) (m,k),前面 m m m时刻的 V V V,维度为