Multi-Head Attention和Multi-Query Attention的计算分析

本文主要是阅读论文《Fast Transformer Decoding: One Write-Head is All
You Need
》的学习记录,这是一篇2019年的改善Multi-Head Attention带来的显存占用瓶颈问题的一种解决方案。

本文根据论文的撰写顺序,从Multi-Head Attention的代码开始,分析MHA在训练阶段和推理阶段的计算量和显存占用参数量;然后给出Multi-Query Attention的代码,分析MQA在训练阶段和推理阶段的计算量和显存占用情况。

提前预备:

  1. 假设读者了解decoder解码器的结构,包括L层Transformer,每层Transformer由一个self-attention层和两个MLP层构成,self-attention层默认为采用H头的注意力层。
  2. 了解在训练过程中,是通过mask来实现并行训练的。如果对这部分不是很熟悉,可以去看一下相关的知乎或者博客。
  3. 了解在推理过程中,模型的输入是当前时刻m的word embedding输入,输出是当前时刻的输出。在中间计算过程中需要用到过去所有时刻的输入来计算Key和Value,实现注意力机制,因此max_length越长的模型,推理的时间开销越大,因为模型需要一次次的重新计算key和value张量。因此在实现中往往采用KVcache的方式来进行加速。

关于KVcache的概念和计算可以参考知乎的文章《分析transformer模型的参数量、计算量、中间激活、KV cache
在这里为了加快模型的推理速度,网络的输入为当前时刻 i n d e x = i index=i index=i的word embedding,模型会在self-attention层保存前面 ( i − 1 ) (i-1) (i1) K , V K,V K,V的值 p r e v i o u s _ K , p r e v i o u s _ V previous\_K,previous\_V previous_K,previous_V,然后在得到当前 i n d e x = i index=i index=i n e w _ k , n e w _ v new\_k,new\_v new_k,new_v之后,
更新 p r e v i o u s _ K = c o n c a t ( p r e v i o u s _ k , n e w _ k ) previous\_K=concat(previous\_k,new\_k) previous_K=concat(previous_k,new_k),
更新 p r e v i o u s _ V = c o n c a t ( p r e v i o u s _ v , n e w _ v ) previous\_V=concat(previous\_v,new\_v) previous_V=concat(previous_v,new_v);
这样做的好处是加快当前token的推理速度,缺点是增加了模型的显存占用。且随着max_length的增加,这部分的 K V c a c h e KV_{cache} KVcache会成为推理的显存占用瓶颈。

一、 Attention_function

这部分主要介绍的是self-attention中的q,K,V是怎么实现注意力机制的。
需要注意的是这里的 q q q m m m时刻的hidden表示,而 K , V K,V K,V为序列中所有时刻的hidden表示,因此维度为 m m m.
输入:单个时刻的输入 q q q,维度为 ( 1 , k ) (1,k) (1,k),前面m个时刻的 K K K,维度为 ( m , k ) (m,k) (m,k),前面 m m m时刻的 V V V,维度为 ( m , v ) (m,v) (m,v)
输出:根据自注意力机制得到的输出 y y y
在这里插入图片描述

首先计算 q q q m m m个key的相似度,得到一个 m m m维的权重,然后将权重进行sigmoid标准化。这样得到的权重的每一个 W i W_i Wi都表示q和第 i i i个Key的相似度或者说相关度,其值越大表示越相关。
将这个权重和对应的 v a l u e value value相乘,也就是赋予对应位置的 V i V_i Vi更高的注意力权重。最后的返回值就是这个将带权重的Value叠加作用到最终结果的一个过程。

上面这个过程就是一个简单的attention过程。

二、Multi-head Attention

从上述简单的attention过程,引入Multihead的概念。

  • 输入 x x x:张量维度为 d d d
  • M M M为所有 m m m时刻的输入,维度为 [ m , d ] [m,d] [m,d]
  • P _ q P\_q P_q实现的功能为:1. 提供一个线性映射转换,2. 转换后的 x x x reshape成 h h h个head,用于表征Multi-head这个概念。
  • P _ k P\_k P_k, P _ v P\_v P_v的功能为:1. 为 m m m个时刻的输入向量 M M M提供一个线性映射变换,2. 将转换之后的 k e y key key, v a l u e value value reshape成Multi-head
  • h h h表示head的数量, k k k为每个head中张量的维度,一般的实现中, h × k = d h \times k=d h×k=d,而且 k = v k=v k=v,也就是key和value的最后一个维度是一致的。

在论文的表示中,大写的 K V KV KV表示张量名称,小写的 k 、 v k、v kv表示长度。需要注意不要搞混了。
在这里插入图片描述

可以看到Multi-head Attention和上面的最简单形式的attention相比:

  1. 多了 q K V qKV qKV的计算方式,不再是直接存在 q , K , V q,K,V q,K,V,而是给出了 K V KV KV是通过 m m m个时刻的输入和投影矩阵 P _ k , P _ v P\_k,P\_v P_k,P_v计算得到, q q q是通过输入 x x x P _ q P\_q P_q计算得到的。
  2. 多了head的概念,将当前时刻的输入 x x x k e y key key v a l u e value value的计算结果都reshape成h个集合。
  3. 输出也不是简单的通过value的权重和得到,而是需要增加一个输出映射 P o P_o Po

三、训练过程中的Multi-Head-attention

训练过程和上面这个过程的区别在于

  1. 训练过程多了batch这个维度,可以同时计算不同序列的输出。提高了显存利用率。
  2. 训练过程有GT,知道每个时刻的输入和输出,因此可以通过mask方式来实现并行计算,可以计算同一个序列上不同位置的输出。而不需要向推理那样等待前一个时刻的输出,来作为后一时刻的输入。

符号说明

  • m m m表示的整个序列的输入,通过mask来实现遮蔽指定位置之后的其他输入。
  • 这里的 n n n表示 n n n个时刻的输入(因为原因2,所以可以并行前向推理),也就是可以通过一次前向过程计算 n n n个时刻的输出。一般实现的时候, n = m n=m n=m。这里的 n n n一般为序列的长度,也就是会在依次前向过程中预测整个序列的每一个元素。通过 m a s k [ n , m ] mask[n,m] mask[n,m]来实现mask掉后面的输入。具体来说,假设 m = n = 10 m=n=10 m=n=10,那么对于时刻 i = 4 i=4 i=4来说,他的 m a s k = [ 1 , 1 , 1 , 1 , − i n f , − i n f , − i n f , . . . ] mask=[1,1,1,1,-inf,-inf,-inf,...] mask=[1,1,1,1,inf,inf,inf,...],这样在计算KV计算只使用该时刻之前的输入,而不使用后面时刻的输入。

其余的和上面的基本没有什么变化
在这里插入图片描述
分析一下计算量,建立在

  1. b = b a t c h s i z e b=batchsize b=batchsize,
  2. n = m = n=m= n=m=序列的长度,
  3. h ∗ k = d h*k=d hk=d
  4. k = v k=v k=v

这四条实践经验的基础上,我们知道矩阵乘法的计算量为最终输出的元素个数乘以每个元素的计算量。 我们以矩阵一次乘-加运算作为计算基础。

  • Q K V QKV QKV的计算量都是 b n d 2 bnd^2 bnd2,那么总计算量为 3 b n d 2 3bnd^2 3bnd2
  • logits的计算量为 b h n 2 k = b n 2 d bhn^2k=bn^2d bhn2k=bn2d
  • weights的计算量为 (忽略一下)
  • O计算量为 b h n v ∗ n = b h n 2 v = b n 2 d bhnv*n=bhn^2v=bn^2d bhnvn=bhn2v=bn2d
  • Y的计算量 b n d ∗ d = b n d 2 bnd*d=bnd^2 bndd=bnd2

综合来看,计算复杂度为 θ ( b n d 2 ) \theta (bnd^2) θ(bnd2),考虑到 n ⩽ d n \leqslant d nd的情况:

计算参数量

  • 1. X , M , Q , K , V , O , Y X,M,Q,K,V,O,Y X,M,Q,K,V,O,Y的参数都是bnd,那么总参数量为 6 b n d 6bnd 6bnd,
  • 2.logits和weights总参数量为 2 b h n 2 2bhn^2 2bhn2,
  • 3.参数 P _ v , P _ q , P _ k , P _ o P\_v,P\_q,P\_k,P\_o P_v,P_q,P_k,P_o的参数总量为 4 d 2 4d^2 4d2.

考虑到1的部分元素是中间激活值,不会保存,2的全部元素是中间激活值,不会保存。那么空间复杂度可以表示为O(bnd+d^2)

四、推理过程的MHA

这里着重讨论的是生成式模型,因为生成式模型的在某个时刻 i i i的输出依赖于从0时刻到当前时刻的所有输入。且在当前时刻的输出得到之前,无法计算后面时刻的输出,因此不像训练阶段那样可以并行执行前向过程。

但是前面时刻的输入主要是为了计算 K V KV KV q q q只和当前时刻的输入有关。因此想要加速推理,可以将 K V KV KV在推理过程中进行缓存保留。这也就是 K V c a c h e KVcache KVcache,这种方式存在的问题是 K V c a c h e KVcache KVcache是需要占用显存的。
在这里插入图片描述
这个公式和上面训练阶段的公式的区别在于 K V KV KV的计算.

K V KV KV是通过之前时刻的 p r e v _ K prev\_K prev_K和当前时刻根据当前输入 M M M(我觉得这里的 M M M X X X表达的是一个意思,都是当前时刻的输入,为了和前面一章节的公式保持一致所以才这么写的。但是不确定我的这种理解对不对,欢迎讨论或者告知这里的 M M M是什么)计算得到的 n e w _ K new\_K new_K, n e w _ V new\_V new_V.

类似上一章节的计算复杂度推理,这里的考虑到不能并行计算的情况,那么对于长度为 n n n的序列,需要计算 n n n次,那么计算复杂度为 θ ( b n d 2 ) \theta (bnd^2) θ(bnd2)

空间复杂度也是同样的,在一次计算中,不考虑中间激活值,主要保存是是 K V c a c h e KVcache KVcache,以及矩阵 P _ q , P _ k , P _ v , P _ o P\_q,P\_k,P\_v,P\_o P_q,P_k,P_v,P_o.前者的参数量为 O ( b n 2 d ) O(bn^2d) O(bn2d),后者的参数量为 O ( d 2 ) O(d^2) O(d2)。可以看到随着序列长度 n n n的增加或者隐藏层深度 d d d的增加,这部分的显存占用将会称为瓶颈。

五、Multi-Query Attention

介绍完MHA之后,就可以介绍MQA了。
在这里插入图片描述

左边是MHA,右边是MQA,中间是没有提到过的GQA(Group-Query Attention,可以理解为不那么激进的MQA)。

可以看到MHA中每个head独立维护自己的 K V c a c h e KVcache KVcache。在MQA中则是同一层Transformer中的所有head维护一个 K V c a c h e KVcache KVcache,那么 K V c a c h e KVcache KVcache的显存可以明显的减少 h h h倍。

另外KV的投影映射矩阵 P _ k , p _ v P\_k,p\_v P_k,p_v也会同样的减少 h h h倍。
在这里插入图片描述
这里好像就没有什么好讲的啦~~和MHA的第三章节唯一的不同就是算法前面提到的两点。其他都是一样一样的

六、 推理阶段的MQA

在这里插入图片描述
x x x是某个时刻的输入,维度为 [ b , d ] [b,d] [b,d]
K V c a c h e KVcache KVcache的更新不再独立于每个head,而是所有head共享,因此 P k , P v , p r e v _ K , p r e v _ V P_k,P_v,prev\_K,prev\_V Pk,Pv,prev_K,prev_V的维度都少了一维。

七、实验对比

作者在英文转德文的数据集上进行了实验,采用的模型是具有2.1亿参数、6层layers,d_model=1024,head=8,且位置编码可学习的模型。
在对比MHA和MQA的时候,为了保持总参数量不变,会将MQA中的MLP层的hidden layer从1024*4增大为5440.

训练的效果评估指标如下所示,可以看到相比于MHA有轻微的性能衰减。PPL指标越小越好,BLEU是用于衡量翻译质量的数据集,越大越好。
在这里插入图片描述

从训练时间和推理时间上看,训练时有轻微的减少。对于单个training step从MHA的433ms减少到MQA的425ms。这里的训练上并没有很明显的提升。

在推理上的提升是很明显的。实验是在一个英文转德文的数据集上进行的,encoder的长度为128,decoder的长度也是128。encoder模块推理需要222ms,在MHA中,decoder需要47ms。在MQA上,encoder的时间为195ms,decoder只需要3.9ms
下表中时间是每个token上的时间。所以和上面我们提到的时间看起来不同,实际是一样的。
在这里插入图片描述

  • 2
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
### 回答1: 多头注意力代码(Multi-Head Attention Code)是一种用于自然语言处理的机器学习技术,它可以帮助模型同时从多个表征空间中提取信息,从而提高模型的准确性。它的主要作用是通过使用多头的注意力机制,来计算输入的表征空间之间的相似性,从而使模型更加准确。 ### 回答2: multi-head attention是一种用于处理序列数据中的深度学习模型。它通过并行地学习多个注意力头,可以捕获不同远距离依赖关系和注意力机制在不同空间维度上的变换。下面是描述一个基本的multi-head attention的代码。 首先,我们需要引入所需的Python库,包括numpy和torch: ```python import numpy as np import torch import torch.nn as nn import torch.nn.functional as F ``` 接下来,我们定义一个MultiHeadAttention类,继承自nn.Module类,以便在PyTorch中构建模型: ```python class MultiHeadAttention(nn.Module): def __init__(self, d_model, num_heads): super(MultiHeadAttention, self).__init__() self.num_heads = num_heads self.d_model = d_model self.query_fc = nn.Linear(d_model, d_model) self.key_fc = nn.Linear(d_model, d_model) self.value_fc = nn.Linear(d_model, d_model) self.fc = nn.Linear(d_model, d_model) def forward(self, query, key, value): batch_size = query.size(0) # 通过线性变换获得query、key和value query = self.query_fc(query) key = self.key_fc(key) value = self.value_fc(value) # 将输入的query、key和value分割为不同的注意力头 query = query.view(batch_size * self.num_heads, -1, self.d_model // self.num_heads) key = key.view(batch_size * self.num_heads, -1, self.d_model // self.num_heads) value = value.view(batch_size * self.num_heads, -1, self.d_model // self.num_heads) # 计算注意力得分 scores = torch.bmm(query, key.transpose(1, 2)) scores = scores / np.sqrt(self.d_model // self.num_heads) attn_weights = F.softmax(scores, dim=-1) # 使用注意力得分加权计算value output = torch.bmm(attn_weights, value) # 将分割的注意力头拼接起来 output = output.view(batch_size, -1, self.d_model) # 通过线性变换得到最终的输出 output = self.fc(output) return output ``` 在上面的代码中,我们首先定义了MultiHeadAttention类的初始化方法,在这个方法中,我们传入注意力头的数量num_heads和输入维度d_model。然后,我们定义了query、key和value的线性变换层。在forward方法中,我们首先通过线性变换得到query、key和value,然后将它们分成不同的注意力头。接下来,我们计算注意力得分,并使用注意力得分加权计算value。最后,我们将分割的注意力头拼接起来,并通过线性变换得到最终的输出。 以上就是一个基本的multi-head attention的代码实现。在实际使用中,我们可以根据需求对其进行修改和扩展。 ### 回答3: multi-head attention是一种用于自然语言处理的注意力机制,用于对输入序列进行加权表示。在代码实现中,multi-head attention可以分为以下几个步骤: 1. 首先,需要定义输入序列x和相关的参数,如隐藏层大小和注意力头数。 2. 然后,将输入序列通过线性变换得到q、k和v矩阵,即对q、k、v分别乘以权重矩阵Wq、Wk和Wv。 3. 接下来,将q、k和v矩阵分别切分成多个头,即将q、k、v矩阵按行分成n个头。 4. 对于每个头,计算注意力权重。首先,计算q和k的点乘,然后除以一个可调节的缩放因子根号d,其中d为隐藏层大小。将结果通过softmax函数得到注意力权重。 5. 将注意力权重与v矩阵相乘,得到每个头的加权表示。 6. 将每个头的加权表示拼接起来,得到最终的加权表示。 7. 最后,通过线性变换将加权表示映射回原始的隐藏层大小。 以上就是multi-head attention的代码实现过程,通过这个过程可以对输入序列进行加权表示,从而提取关键信息。每个头的注意力权重计算可以独立进行,可以并行计算,提高了计算效率。multi-head attention在自然语言处理中应用广泛,如机器翻译、文本摘要等任务中都取得了很好的效果。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值