一、概念
最近DeepSeek概念的大火也使得许多在DeepSeek系列模型中用到的技术重新进入大家的视野中,Grouped-Query Attention就是其中一种(应用于模型DeepSeek LLM中)。GQA最初由Google在2023年发表的论文《GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints》中提出。本文详细介绍GQA的概念及原理。
译文:
多查询注意力(MQA)仅使用单个键值头,极大地加快了解码器推理速度。然而,MQA 可能会导致质量下降,而且仅仅为了更快的推理而训练一个单独的模型可能并不理想。我们(1)提出了一种将现有的多头语言模型检查点上训练为具有 MQA 的模型的方法,仅使用原始预训练计算量的 5%;(2)引入分组查询注意力(GQA),它是多查询注意力的一种泛化,使用中间数量(多于一个,少于查询头数量)的键值头。我们表明,经过上训练的 GQA 在速度与 MQA 相当的情况下,实现了接近多头注意力的质量。
二、核心原理
1、Uptraining
作者认为,从多头模型生成多查询模型只需要两个步骤:首先,转换检查点;其次,额外的预训练以允许模型适应其新结构。因此,MQA将键key和值value头的投影矩阵平均池化为单个投影矩阵,这比选择其中的单个键和值头或从头开始随机初始化新的键和值头效果更好。
2、Grouped-Query Attention
GQA将查询头分成 G 组,每个组共享一个键头和值头,GQA-G 是指带有 G 组的分组查询。举例来说,GQA-1 具有单个组,因此具有单个键头和值头,等价于 MQA;而 GQA-H,组等于头的数量,等价于 MHA。
当将多头检查点转换为 GQA 检查点时,作者通过合并该组中的所有原始头来构建每个组键和值头。中间数量的G组会导致一个插值模型,它比 MQA 质量更高,但比 MHA 更快。从 MHA 到 MQA 将 H 个键和值头减少为单个键和值头,从而减少键值缓存的大小,因此需要加载的数据量减少了 H 倍。然而,较大的模型通常会扩展头的数量,因此MQA代表了内存带宽和容量的更大程度削减。GQA 允许随着模型大小的增加保持带宽和容量的相同比例减少。此外,GQA 没有应用于编码器自关注层,因为编码器表示是并行计算的,内存带宽并不是这类模型的主要瓶颈。
三、python实现
这里我们直接把GitHub开源代码中实现的非官方torch版GQA代码po上来。
from typing import Optional, Tuple, Union
import torch
import torch.nn.functional as F
from einops import einsum, rearrange
from torch import Tensor, nn
def scaled_dot_product_gqa(
query: Tensor,
key: Tensor,
value: Tensor,
dropout: float = 0.0,
scale: Optional[float] = None,
mask: Optional[Tensor] = None,
is_causal: Optional[bool] = None,
need_weights: bool = False,
average_attn_weights: bool = False,
force_grouped: bool = False,
):
"""Scaled dot product attention with support for grouped queries.
Einstein notation:
- b: batch size
- n / s: sequence length
- h: number of heads
- g: number of groups