GQA阅读


欢迎关注微信公众号InfiniReach,这里有更多AI大模型的前沿算法与工程优化方法分享
请添加图片描述

摘要

多query注意(MQA)仅使用单个key头,大大加快了解码器推理速度。然而,MQA可能导致质量下降,而且仅仅为了更快的推理而训练一个单独的模型可能是不可取的。

  1. 提出了一种方法,利用原始预训练计算量的5%将现有的多头语言模型checkpoint训练成具有MQA的模型;
  2. 引入分组查询注意(GQA),这是一种MQA的泛化,介于MHA和MQA之间,with single key and value heads per subgroup of query heads. 在一个query组里有一对key value。结果表明,改进后的GQA能以与MQA相当的速度获得接近多头注意力的质量。

MHA(Multi-head Attention)是标准的多头注意力机制,h个Query、Key 和 Value 矩阵。
MQA(Multi-Query Attention,Fast Transformer Decoding: One Write-Head is All You Need)是多查询注意力的一种变体,也是用于自回归解码的一种注意力机制。与MHA不同的是,MQA 让所有的头之间共享同一份 Key 和 Value 矩阵,每个头只单独保留了一份 Query 参数,从而大大减少 Key 和 Value 矩阵的参数量。

self.Wqkv = nn.Linear(                           # 【关键】Multi-Query Attention 的创建方法
            d_model,
            d_model + 2 * self.head_dim,                 # 只创建 query 的 head 向量,所以只有 1 个 d_model
            device=device,                               # 而 key 和 value 则只共享各自的一个 head_dim 的向量
        )

方法

后续训练

从MHA生成MQA分为两个步骤:

  1. 首先,转换checkpoint,图1显示了将MHA checkpoint转换为MQA checkpoint的过程。key头和value头的投影矩阵被平均池化为单个投影矩阵,我们发现这比选择单个key头和value头或从头随机初始化新的键头和值头效果更好。
    在这里插入图片描述
  2. 其次,进行额外的预训练,以允许模型适应其新结构。
    然后在相同的预训练上对转换后的checkpoint进行原始训练步骤的小比例α的预训练。

Grouped-query attention

分组查询注意GQA将query头分为G组,每组共享一个键头和值头。

GQA-G是指具有G组的分组查询。GQA-1具有单个组,因此具有单个键值头,相当于MQA,而GQA-H具有与头数量相等的组,相当于MHA。
在这里插入图片描述

图2显示了分组查询注意和多头/多查询注意的比较。当将多头checkpoint转换为GQA checkpoint时,我们通过平均池化该组中的所有原始头来构建每个组键和值头。

中间数量的组导致插值模型比MQA质量更高,但比MHA速度更快,并且,正如我们将展示的那样,代表了有利的权衡。从MHA到MQA将H键和值头减少到单个键和值头,减少了键-值缓存的大小,因此需要加载的数据量减少了H倍。然而,更大的模型通常会扩展头的数量,这样MQA在内存带宽和容量方面都代表了更大的削减。GQA允许我们在模型大小增加时保持相同比例的带宽和容量减少。

此外,由于KV-cache随模型尺寸的增大而增大,而模型FLOPs和参数随模型尺寸的平方而增大,较大的模型受到的注意力带来的内存带宽开销相对较小。最后,大型模型的standard sharding通过模型分区的数量复制单个键和值头(Pope et al., 2022);GQA消除了这种分区的浪费。因此,我们期望GQA能够为更大的模型提供一个特别好的权衡。

实验

所有模型都基于T5.1.1架构(rafael等人,2020),由JAX (Bradbury等人,2018)、flex (Heek等人,2020)和Flaxformer1实现。对于我们的主要实验,我们考虑了具有多头注意力的T5 Large和XXL,以及具有多查询和分组查询注意力的T5 XXL的升级训练版本。我们将MQA和GQA应用于解码器自注意和交叉注意,但不应用于编码器自注意。

从公共T5.1.1检查点初始化经过升级训练的模型。将键头和值头平均池到适当的MQA或GQA结构中,然后使用原始预训练设置对原始预训练步骤进行进一步的α比例预训练(rafael et al., 2020)。α = 0.05

我们评估了CNN/Daily Mail (Nallapati等人,2016)、arXiv和PubMed (Cohan等人,2018)、MediaSum (Zhu等人,2021)和Multi-News (Fabbri等人,2019)的摘要数据集;翻译数据集WMT 2014;和问答数据集TriviaQA (Joshi et al., 2017)。我们没有对GLUE等流行的分类基准进行评估(Wang et al., 2019),因为自回归推理不太适用于这些任务。

对于微调,我们使用恒定的学习率为0.001,批大小为128,所有任务的dropout为0.1。CNN/Daily Mail和WMT使用的输入长度为512,输出长度为256。其他摘要数据集使用输入长度2048和输出长度512。最后,TriviaQA使用输入长度2048和输出长度32。我们训练直到收敛,并选择最高性能的checkpoint。我们使用贪婪解码进行推理。
在这里插入图片描述

下图提供实验来研究不同建模选择的影响。我们评估了具有代表性的任务子样本的性能:CNN/Daily Mail,(简短摘要),MultiNews(长格式摘要)和TriviaQA(问答)。
在这里插入图片描述
不同检查点转换方法对T5-Large上训练到比例为α = 0.05的MQA的性能比较。“Mean”是指将键和值头 做mean-pool,“First”选择第一个头,“Random”从头开始初始化头。

图6展示了GQA组的数量对推断速度的影响。对于较大的模型,KV缓存的内存带宽开销约束较少(Shazeer, 2019),而由于头数量的增加,键值大小的减少更为明显。因此,增加来自MQA的组的数量最初只会导致适度的减速,随着我们接近MHA而增加成本。我们选择了8组作为有利的中间立场。
在这里插入图片描述
GQA- xxl的每个样本时间作为输入长度2048和输出长度512的GQA组数量的函数。从1个(MQA)组增加到8个组会增加适度的推理开销,增加更多组会增加成本

Llama2

为了创建新的Llama 2模型家族,我们从Touvron等人(2023)中描述的预训练方法开始,使用优化的自回归变压器,但进行了一些更改以提高性能。具体来说,我们执行了更健壮的数据清理,更新了我们的数据混合,训练了超过40%的总令牌,将上下文长度增加了一倍,并使用分组查询关注(GQA)来提高更大模型的推理可伸缩性。表1比较了新Llama 2模型与Llama 1模型的属性。

我们采用了Llama 1中的大部分预训练设置和模型架构。我们使用标准变压器架构(Vaswani等人,2017),使用RMSNorm应用预归一化(Zhang和Sennrich, 2019),使用SwiGLU激活函数(Shazeer, 2020)和旋转位置嵌入(RoPE, Su等人,2022)。与Llama 1的主要架构差异包括增加了上下文长度和分组查询关注(GQA)。我们在附录A.2.1节中通过烧蚀实验详细说明这些差异的重要性。

在这里插入图片描述
更大的模型——34B和70B——使用分组查询注意(GQA)来提高推理的可扩展性

GQA

自回归解码的标准做法是缓存序列中前面标记的键(K)和值(V)对,从而加快注意力计算。然而,随着上下文窗口或批处理大小的增加,多头注意(MHA)模型中与KV缓存大小相关的内存成本显著增加。对于较大的模型,KV缓存大小成为瓶颈,键和值预测可以在多个头之间共享,而不会导致性能下降(Chowdhery et al., 2022)。可以使用具有单个KV投影的原始多查询格式(MQA, Shazeer, 2019)或具有8 KV投影的分组查询关注变体(GQA, Ainslie等人,2023)。

在表18中,我们用MHA基线比较了MQA和GQA变体。我们用150B令牌训练所有模型,同时保持固定的30B模型大小。为了在GQA和MQA之间保持相似的总体参数计数,我们增加了前馈层的维度,以补偿注意层的减少。对于MQA变体,我们将FFN维度增加1.33倍,对于GQA变体,我们将其增加1.3倍。从结果中,我们观察到GQA变体在大多数评估任务上的表现与MHA基线相当,并且平均优于MQA变体。
在这里插入图片描述

为了优化延迟,我们在具有张量并行性的单个节点上使用8个a100来托管最大的模型(Shoeybi等人,2019)。在这种设置中,由于头的数量低于gpu的数量,因此不能再跨头进行MQA分片。要么在所有gpu中复制KV值(使KV缓存大小等于GQA),要么在批处理维度上进行分片(Pope et al., 2022)。然而,后者可能会使推理服务复杂化,因为它只有在批处理大小大于分片数量时才有效,并且在所有情况下,额外的通信成本都不值得。

因此,基于消融结果和缩放推理的易用性,对于34B和70B Llama 2模型,我们选择使用GQA而不是MQA。图24显示了与MHA基线相比,30B GQA和MQA消融模型的推理速度变化情况,使用8 x 80 GiB a100进行张量并行实验。在这些运行中,我们只是在所有gpu中复制MQA的KV头,因此MQA的KV缓存大小与GQA相等,并且这两个变体的行为非常相似(MQA只是具有稍微大一点的FFN维度)。

在这里插入图片描述
Multi-query variants可以在更大的批处理大小下实现更高的吞吐量,并且在较小的批处理上显示类似的延迟。输出长度固定为128个令牌。第一个数据点对应于批处理大小1,然后我们将其加倍,直到模型耗尽内存。对于256个令牌的上下文,MHA变体在批大小为1024时触发内存不足错误,而对于2k个上下文,批大小为128时触发内存不足错误,而MQA和GQA在这些设置中可以成功运行。

flash attention2

多查询注意(Multi-query attention, MQA)[15]和分组查询注意(group -query attention, GQA)[1]是注意的变体,其中多个查询头关注同一个键和值的头,以便在推理过程中减少KV缓存的大小。我们不必为计算复制键头和值头,而是隐式地操作头的索引来执行相同的计算。在逆向传递中,我们需要对不同头部的梯度dK和dV求和,这些头部是隐式重复的。

  • 1
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
在openeuler上安装gqa,可以按照以下步骤进行操作: 1. 首先,确保你已经安装了openeuler操作系统。可以从官方网站下载并按照指示进行安装。 2. 接下来,打开终端并使用管理员权限登录。你可以使用su命令切换到root用户。 3. 然后,使用以下命令安装gqa所需的依赖项: ``` dnf install python3 python3-pip -y ``` 4. 安装pip工具后,可以使用以下命令安装gqa: ``` pip3 install gqa ``` 5. 安装完成后,你可以使用以下命令验证gqa是否成功安装: ``` gqa --help ``` 如果成功安装,你将看到gqa的帮助信息。 总结:要在openeuler上安装gqa,你需要首先安装openeuler操作系统,然后使用终端和管理员权限登录,安装必要的依赖项,最后使用pip工具安装gqa。<span class="em">1</span><span class="em">2</span><span class="em">3</span> #### 引用[.reference_title] - *1* *2* [Asp.net面试题](https://blog.csdn.net/weixin_34082789/article/details/94306699)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v93^chatsearchT3_2"}}] [.reference_item style="max-width: 50%"] - *3* [插入记录时单引号的处理](https://blog.csdn.net/weixin_42400383/article/details/114219027)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v93^chatsearchT3_2"}}] [.reference_item style="max-width: 50%"] [ .reference_list ]
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值