Multi-Head Attention和Multi-Query Attention的计算分析

本文主要是阅读论文《Fast Transformer Decoding: One Write-Head is All
You Need
》的学习记录,这是一篇2019年的改善Multi-Head Attention带来的显存占用瓶颈问题的一种解决方案。

本文根据论文的撰写顺序,从Multi-Head Attention的代码开始,分析MHA在训练阶段和推理阶段的计算量和显存占用参数量;然后给出Multi-Query Attention的代码,分析MQA在训练阶段和推理阶段的计算量和显存占用情况。

提前预备:

  1. 假设读者了解decoder解码器的结构,包括L层Transformer,每层Transformer由一个self-attention层和两个MLP层构成,self-attention层默认为采用H头的注意力层。
  2. 了解在训练过程中,是通过mask来实现并行训练的。如果对这部分不是很熟悉,可以去看一下相关的知乎或者博客。
  3. 了解在推理过程中,模型的输入是当前时刻m的word embedding输入,输出是当前时刻的输出。在中间计算过程中需要用到过去所有时刻的输入来计算Key和Value,实现注意力机制,因此max_length越长的模型,推理的时间开销越大,因为模型需要一次次的重新计算key和value张量。因此在实现中往往采用KVcache的方式来进行加速。

关于KVcache的概念和计算可以参考知乎的文章《分析transformer模型的参数量、计算量、中间激活、KV cache
在这里为了加快模型的推理速度,网络的输入为当前时刻 i n d e x = i index=i index=i的word embedding,模型会在self-attention层保存前面 ( i − 1 ) (i-1) (i1) K , V K,V K,V的值 p r e v i o u s _ K , p r e v i o u s _ V previous\_K,previous\_V previous_K,previous_V,然后在得到当前 i n d e x = i index=i index=i n e w _ k , n e w _ v new\_k,new\_v new_k,new_v之后,
更新 p r e v i o u s _ K = c o n c a t ( p r e v i o u s _ k , n e w _ k ) previous\_K=concat(previous\_k,new\_k) previous_K=concat(previous_k,new_k),
更新 p r e v i o u s _ V = c o n c a t ( p r e v i o u s _ v , n e w _ v ) previous\_V=concat(previous\_v,new\_v) previous_V=concat(previous_v,new_v);
这样做的好处是加快当前token的推理速度,缺点是增加了模型的显存占用。且随着max_length的增加,这部分的 K V c a c h e KV_{cache} KVcache会成为推理的显存占用瓶颈。

一、 Attention_function

这部分主要介绍的是self-attention中的q,K,V是怎么实现注意力机制的。
需要注意的是这里的 q q q m m m时刻的hidden表示,而 K , V K,V K,V为序列中所有时刻的hidden表示,因此维度为 m m m.
输入:单个时刻的输入 q q q,维度为 ( 1 , k ) (1,k) (1,k),前面m个时刻的 K K K,维度为 ( m , k ) (m,k) (m,k),前面 m m m时刻的 V V V,维度为

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值