Note
- MQA, GQA, and MLA are all, at their core, optimizations built around the same question: how to reduce the KV cache while preserving model quality as much as possible.
- From the layer perspective, MQA/GQA can be viewed as intra-layer KV cache sharing, while the idea proposed by YOCO is inter-layer KV cache sharing. Inter-layer sharing can in theory cut the KV cache memory requirement by up to a factor of N (where N is the number of Transformer layers), and it does not conflict with intra-layer sharing techniques such as MQA and GQA; the two can be combined to reduce KV cache memory dramatically. A back-of-the-envelope size estimate is sketched below.
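To make the savings concrete, here is a minimal sketch (my own illustration, not from the original post): the usual estimate is 2 (K and V) × n_layers × n_kv_heads × head_dim × seq_len × batch × bytes per element. The 7B-class configuration used below is an assumption for illustration only.

def kv_cache_bytes(n_layers, n_kv_heads, head_dim, seq_len, batch, bytes_per_elem=2):
    """Estimate KV cache size: one K and one V tensor per layer, fp16 -> 2 bytes/element."""
    return 2 * n_layers * n_kv_heads * head_dim * seq_len * batch * bytes_per_elem

# Hypothetical 7B-class config: 32 layers, 32 heads, head_dim 128, 4k context, batch 1
mha = kv_cache_bytes(n_layers=32, n_kv_heads=32, head_dim=128, seq_len=4096, batch=1)  # full MHA
gqa = kv_cache_bytes(n_layers=32, n_kv_heads=8,  head_dim=128, seq_len=4096, batch=1)  # intra-layer sharing (GQA)
mqa = kv_cache_bytes(n_layers=32, n_kv_heads=1,  head_dim=128, seq_len=4096, batch=1)  # extreme intra-layer sharing (MQA)
print(f"MHA {mha / 2**30:.2f} GiB | GQA {gqa / 2**30:.2f} GiB | MQA {mqa / 2**30:.3f} GiB")

Inter-layer sharing would further divide whichever of these numbers applies by the number of layers that share a cache.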
Review: Multi-Head Attention
In PyTorch, for example, we can conveniently use the nn.TransformerEncoderLayer or nn.TransformerDecoderLayer classes, each of which bundles a multi-head attention module and an FFN feed-forward network (with residual connections around both) plus layer normalization.
from torch import nn

# Encoder: stack num_layers Transformer encoder layers (hidden_dim, num_head, etc. are defined elsewhere in the module)
encoder_layer = nn.TransformerEncoderLayer(hidden_dim, num_head, dim_feedforward, dropout, activation)
self.transformer = nn.TransformerEncoder(encoder_layer, num_layers)
Looking at the PyTorch source of nn.TransformerEncoderLayer, you can see that it holds a MultiheadAttention module:
class TransformerEncoderLayer(Module):
    __constants__ = ['batch_first', 'norm_first']

    def __init__(self, d_model: int, nhead: int, dim_feedforward: int = 2048, dropout: float = 0.1,
                 activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
                 layer_norm_eps: float = 1e-5, batch_first: bool = False, norm_first: bool = False,
                 device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first,
                                            **factory_kwargs)
        # Implementation of Feedforward model
        self.linear1 = Linear(d_model, dim_feedforward, **factory_kwargs)
        self.dropout = Dropout(dropout)
        self.linear2 = Linear(dim_feedforward, d_model, **factory_kwargs)

        self.norm_first = norm_first
        self.norm1 = LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
        self.norm2 = LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
        self.dropout1 = Dropout(dropout)
        self.dropout2 = Dropout(dropout)

        # Legacy string support for activation function.
        if isinstance(activation, str):
            activation = _get_activation_fn(activation)

        # We can't test self.activation in forward() in TorchScript,
        # so stash some information about it instead.
        if activation is F.relu or isinstance(activation, torch.nn.ReLU):
            self.activation_relu_or_gelu = 1
        elif activation is F.gelu or isinstance(activation, torch.nn.GELU):
            self.activation_relu_or_gelu = 2
        else:
            self.activation_relu_or_gelu = 0
        self.activation = activation

    def __setstate__(self, state):
        super().__setstate__(state)
        if not hasattr(self, 'activation'):
            self.activation = F.relu

    def forward(
            self,
            src: Tensor,
            src_mask: Optional[Tensor] = None,
            src_key_padding_mask: Optional[Tensor] = None,
            is_causal: bool = False) -> Tensor:
        r"""Pass the input through the encoder layer.

        Args:
            src: the sequence to the encoder layer (required).
            src_mask: the mask for the src sequence (optional).
            is_causal: If specified, applies a causal mask as src_mask.
                Default: ``False``.
            src_key_padding_mask: the mask for the src keys per batch (optional).

        Shape:
            see the docs in Transformer class.
        """
        src_key_padding_mask = F._canonical_mask(
            mask=src_key_padding_mask,
            mask_name="src_key_padding_mask",
            other_type=F._none_or_dtype(src_mask),
            other_name="src_mask",
            target_type=src.dtype
        )

        src_mask = F._canonical_mask(
            mask=src_mask,
            mask_name="src_mask",
            other_type=None,
            other_name="",
            target_type=src.dtype,
            check_other=False,
        )

        # see Fig. 1 of https://arxiv.org/pdf/2002.04745v1.pdf
        why_not_sparsity_fast_path = ''
        if not src.dim() == 3:
            why_not_sparsity_fast_path = f"input not batched; expected src.dim() of 3 but got {src.dim()}"
        elif self.training:
            why_not_sparsity_fast_path = "training is enabled"
        elif not self.self_attn.batch_first:
            why_not_sparsity_fast_path = "self_attn.batch_first was not True"
        elif not self.self_attn._qkv_same_embed_dim:
            why_not_sparsity_fast_path = "self_attn._qkv_same_embed_dim was not True"
        elif not self.activation_relu_or_gelu:
            why_not_sparsity_fast_path = "activation_relu_or_gelu was not True"
        elif not (self.norm1.eps == self.norm2.eps):
            why_not_sparsity_fast_path = "norm1.eps is not equal to norm2.eps"
        elif src.is_nested and (src_key_padding_mask is not None or src_mask is not None):
            why_not_sparsity_fast_path = "neither src_key_padding_mask nor src_mask are not supported with NestedTensor input"
        elif self.self_attn.num_heads % 2 == 1:
            why_not_sparsity_fast_path = "num_head is odd"
        elif torch.is_autocast_enabled():
            why_not_sparsity_fast_path = "autocast is enabled"

        if not why_not_sparsity_fast_path:
            tensor_args = (
                src,
                self.self_attn.in_proj_weight,
                self.self_attn.in_proj_bias,
                self.self_attn.out_proj.weight,
                self.self_attn.out_proj.bias,
                self.norm1.weight,
                self.norm1.bias,
                self.norm2.weight,
                self.norm2.bias,
                self.linear1.weight,
                self.linear1.bias,
                self.linear2.weight,
                self.linear2.bias,
            )

            # We have to use list comprehensions below because TorchScript does not support
            # generator expressions.
            if torch.overrides.has_torch_function(tensor_args):
                why_not_sparsity_fast_path = "some Tensor argument has_torch_function"
            elif not all((x.is_cuda or 'cpu' in str(x.device)) for x in tensor_args):
                why_not_sparsity_fast_path = "some Tensor argument is neither CUDA nor CPU"
            elif torch.is_grad_enabled() and any(x.requires_grad for x in tensor_args):
                why_not_sparsity_fast_path = ("grad is enabled and at least one of query or the "
                                              "input/output projection weights or biases requires_grad")

            if not why_not_sparsity_fast_path:
                merged_mask, mask_type = self.self_attn.merge_masks(src_mask, src_key_padding_mask, src)
                return torch._transformer_encoder_layer_fwd(
                    src,
                    self.self_attn.embed_dim,
                    self.self_attn.num_heads,
                    self.self_attn.in_proj_weight,
                    self.self_attn.in_proj_bias,
                    self.self_attn.out_proj.weight,
                    self.self_attn.out_proj.bias,
                    self.activation_relu_or_gelu == 2,
                    self.norm_first,
                    self.norm1.eps,
                    self.norm1.weight,
                    self.norm1.bias,
                    self.norm2.weight,
                    self.norm2.bias,
                    self.linear1.weight,
                    self.linear1.bias,
                    self.linear2.weight,
                    self.linear2.bias,
                    merged_mask,
                    mask_type,
                )

        x = src
        if self.norm_first:
            x = x + self._sa_block(self.norm1(x), src_mask, src_key_padding_mask, is_causal=is_causal)
            x = x + self._ff_block(self.norm2(x))
        else:
            x = self.norm1(x + self._sa_block(x, src_mask, src_key_padding_mask, is_causal=is_causal))
            x = self.norm2(x + self._ff_block(x))

        return x

    # self-attention block
    def _sa_block(self, x: Tensor,
                  attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor], is_causal: bool = False) -> Tensor:
        x = self.self_attn(x, x, x,
                           attn_mask=attn_mask,
                           key_padding_mask=key_padding_mask,
                           need_weights=False, is_causal=is_causal)[0]
        return self.dropout1(x)

    # feed forward block
    def _ff_block(self, x: Tensor) -> Tensor:
        x = self.linear2(self.dropout(self.activation(self.linear1(x))))
        return self.dropout2(x)
1. Self-attention inside the encoder
As seen above, the encoder uses self-attention, where Q, K, and V all come from the same sequence: for every token of the input (e.g. "i love large language model"), the Q, K, V matrices are obtained by passing the same input $X$ through three different linear transformations, whose projection matrices $W_Q$, $W_K$, $W_V$ are learned weights.
# self-attention block
def _sa_block(self, x: Tensor,
              attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor], is_causal: bool = False) -> Tensor:
    x = self.self_attn(x, x, x,
                       attn_mask=attn_mask,
                       key_padding_mask=key_padding_mask,
                       need_weights=False, is_causal=is_causal)[0]
    return self.dropout1(x)
- To gather the context needed when encoding token $x_i$, the query vector at position $i$ is dotted with the key vectors at every position to produce matching scores $\boldsymbol{q}_i \cdot \boldsymbol{k}_1, \boldsymbol{q}_i \cdot \boldsymbol{k}_2, \ldots, \boldsymbol{q}_i \cdot \boldsymbol{k}_t$.
- To keep overly large scores from saturating the subsequent Softmax (which would lead to vanishing gradients and poor convergence), the scores are divided by the scaling factor $\sqrt{d}$.
- The scaled scores are normalized by Softmax into probabilities (an attention-weight matrix) and multiplied with the value vectors, aggregating the relevant context while suppressing interference from irrelevant information. The whole computation is simply a weighted sum over the values:

$$
\boldsymbol{Z} = \operatorname{Attention}(\boldsymbol{Q}, \boldsymbol{K}, \boldsymbol{V}) = \operatorname{Softmax}\!\left(\frac{\boldsymbol{Q}\boldsymbol{K}^{\top}}{\sqrt{d}}\right)\boldsymbol{V}
$$
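As a quick illustration of the formula (my own minimal sketch, not part of the PyTorch source above; the tensor sizes are made up), the whole computation fits in a few lines:

import math
import torch
import torch.nn.functional as F

def scaled_dot_product(q, k, v):
    """q, k, v: (batch, seq_len, d). Returns the weighted sum of the value vectors."""
    d = q.size(-1)
    scores = q @ k.transpose(-2, -1) / math.sqrt(d)   # (batch, seq, seq) matching scores
    weights = F.softmax(scores, dim=-1)                # each row sums to 1
    return weights @ v                                 # weighted sum of values

x = torch.randn(1, 5, 64)                              # "i love large language model" -> 5 tokens, d = 64
W_q, W_k, W_v = (torch.nn.Linear(64, 64) for _ in range(3))  # learned projections
z = scaled_dot_product(W_q(x), W_k(x), W_v(x))
print(z.shape)  # torch.Size([1, 5, 64])

A real MultiheadAttention additionally splits the projections into h heads and concatenates the per-head outputs.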
2. Cross-attention inside the decoder
The decoder contains both a self-attention module and a cross-attention module:
- Self-attention: Q, K, and V all come from the decoder's output so far, analogous to the encoder.
- Cross-attention: the query comes from the decoder's current output, while the encoder's output serves as key and value, so that the decoder can take the input sequence into account at every generation step.
class TransformerDecoderLayer(Module):
    __constants__ = ['batch_first', 'norm_first']

    def __init__(self, d_model: int, nhead: int, dim_feedforward: int = 2048, dropout: float = 0.1,
                 activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
                 layer_norm_eps: float = 1e-5, batch_first: bool = False, norm_first: bool = False,
                 device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first,
                                            **factory_kwargs)
        self.multihead_attn = MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first,
                                                 **factory_kwargs)
        # Implementation of Feedforward model
        self.linear1 = Linear(d_model, dim_feedforward, **factory_kwargs)
        self.dropout = Dropout(dropout)
        self.linear2 = Linear(dim_feedforward, d_model, **factory_kwargs)

        self.norm_first = norm_first
        self.norm1 = LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
        self.norm2 = LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
        self.norm3 = LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
        self.dropout1 = Dropout(dropout)
        self.dropout2 = Dropout(dropout)
        self.dropout3 = Dropout(dropout)

        # Legacy string support for activation function.
        if isinstance(activation, str):
            self.activation = _get_activation_fn(activation)
        else:
            self.activation = activation

    def __setstate__(self, state):
        if 'activation' not in state:
            state['activation'] = F.relu
        super().__setstate__(state)

    def forward(
        self,
        tgt: Tensor,
        memory: Tensor,
        tgt_mask: Optional[Tensor] = None,
        memory_mask: Optional[Tensor] = None,
        tgt_key_padding_mask: Optional[Tensor] = None,
        memory_key_padding_mask: Optional[Tensor] = None,
        tgt_is_causal: bool = False,
        memory_is_causal: bool = False,
    ) -> Tensor:
        # see Fig. 1 of https://arxiv.org/pdf/2002.04745v1.pdf
        x = tgt
        if self.norm_first:
            x = x + self._sa_block(self.norm1(x), tgt_mask, tgt_key_padding_mask, tgt_is_causal)
            x = x + self._mha_block(self.norm2(x), memory, memory_mask, memory_key_padding_mask, memory_is_causal)
            x = x + self._ff_block(self.norm3(x))
        else:
            x = self.norm1(x + self._sa_block(x, tgt_mask, tgt_key_padding_mask, tgt_is_causal))
            x = self.norm2(x + self._mha_block(x, memory, memory_mask, memory_key_padding_mask, memory_is_causal))
            x = self.norm3(x + self._ff_block(x))

        return x

    # self-attention block
    def _sa_block(self, x: Tensor,
                  attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor], is_causal: bool = False) -> Tensor:
        x = self.self_attn(x, x, x,
                           attn_mask=attn_mask,
                           key_padding_mask=key_padding_mask,
                           is_causal=is_causal,
                           need_weights=False)[0]
        return self.dropout1(x)

    # multihead attention block
    def _mha_block(self, x: Tensor, mem: Tensor,
                   attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor], is_causal: bool = False) -> Tensor:
        x = self.multihead_attn(x, mem, mem,
                                attn_mask=attn_mask,
                                key_padding_mask=key_padding_mask,
                                is_causal=is_causal,
                                need_weights=False)[0]
        return self.dropout2(x)

    # feed forward block
    def _ff_block(self, x: Tensor) -> Tensor:
        x = self.linear2(self.dropout(self.activation(self.linear1(x))))
        return self.dropout3(x)
In the forward method above, after the self-attention _sa_block, the decoder applies cross-attention via self._mha_block. Its first argument, self.norm2(x), is the embedding of the target sequence (used as Q), and its second argument, memory, is the encoder output (used as K and V), which lets the model capture the dependencies between the input sequence and the target sequence (a short usage sketch follows the excerpt below).
x = tgt
if self.norm_first:
    x = x + self._sa_block(self.norm1(x), tgt_mask, tgt_key_padding_mask, tgt_is_causal)
    x = x + self._mha_block(self.norm2(x), memory, memory_mask, memory_key_padding_mask, memory_is_causal)
    x = x + self._ff_block(self.norm3(x))
else:
    x = self.norm1(x + self._sa_block(x, tgt_mask, tgt_key_padding_mask, tgt_is_causal))
    x = self.norm2(x + self._mha_block(x, memory, memory_mask, memory_key_padding_mask, memory_is_causal))
    x = self.norm3(x + self._ff_block(x))
return x
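To see the Q/K/V roles of _mha_block in isolation, here is a minimal standalone usage sketch (my own example; the sizes and tensors are made up): the target-side states act as the query, while the encoder memory supplies both keys and values.

import torch
from torch import nn

d_model, nhead = 64, 4
cross_attn = nn.MultiheadAttention(d_model, nhead, batch_first=True)

memory = torch.randn(2, 10, d_model)  # encoder output: (batch, src_len, d_model) -> K and V
tgt = torch.randn(2, 7, d_model)      # decoder states: (batch, tgt_len, d_model) -> Q

out, attn_weights = cross_attn(query=tgt, key=memory, value=memory)
print(out.shape, attn_weights.shape)  # (2, 7, 64) and (2, 7, 10): every target position attends over the source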
Note
- LLaMA 2's attention mechanism uses GQA (in its larger model sizes). The three mechanisms (MHA, MQA, GQA) are described below:
MHA (Multi-Head Attention)
MHA (Multi-Head Attention) is the standard multi-head attention mechanism with h Query, Key, and Value matrices; the Key and Value weights are not shared across heads, so it is effectively a concatenation of several independent single-head attentions.
In MHA the number of KV heads equals the number of query heads: every query head owns an independent KV head, and attention is computed per head against its own K and V. As models grow deeper and add more heads, the compute and IO of QKV attention (and the KV cache it implies) increase rapidly. MQA and GQA were proposed to mitigate this.
MQA (Multi-Query Attention)
MQA (Multi-Query Attention, introduced in "Fast Transformer Decoding: One Write-Head is All You Need") is a multi-query variant of attention, also aimed at autoregressive decoding. It is the extreme case: only a single KV head is kept, and all query heads share it. All of the per-head differences in attention therefore live in the queries alone, and the model must capture the different aspects of the input hidden states purely through its query heads. The benefit is a drastic reduction in KV cache; the cost is usually some drop in model quality.
GQA (Grouped-Query Attention)
GQA (Grouped-Query Attention, introduced in "GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints") is grouped-query attention.
GQA takes a middle ground between MHA and MQA: the query heads are split into groups, and each group shares one KV head. For example, splitting 8 query heads into 4 groups of 2 yields 4 KV heads in total. GQA reduces both compute and KV cache while keeping the impact on model quality small. A sketch covering MHA, MQA, and GQA through a single n_kv_heads parameter follows below.
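The following is my own minimal sketch (not taken from any particular library) of grouped-query attention: with n_kv_heads = n_heads it degenerates to MHA, with n_kv_heads = 1 it becomes MQA, and any value in between is GQA. Here the KV heads are simply replicated with repeat_interleave so that each group of query heads attends to the same K/V.

import math
import torch
import torch.nn.functional as F

def grouped_attention(q, k, v, n_heads, n_kv_heads):
    """q: (batch, seq, n_heads*head_dim); k, v: (batch, seq, n_kv_heads*head_dim)."""
    b, s, _ = q.shape
    head_dim = q.size(-1) // n_heads
    q = q.view(b, s, n_heads, head_dim).transpose(1, 2)     # (b, n_heads, s, d)
    k = k.view(b, s, n_kv_heads, head_dim).transpose(1, 2)  # (b, n_kv_heads, s, d)
    v = v.view(b, s, n_kv_heads, head_dim).transpose(1, 2)
    # Each group of n_heads // n_kv_heads query heads shares one KV head.
    group = n_heads // n_kv_heads
    k = k.repeat_interleave(group, dim=1)                   # (b, n_heads, s, d)
    v = v.repeat_interleave(group, dim=1)
    scores = q @ k.transpose(-2, -1) / math.sqrt(head_dim)
    out = F.softmax(scores, dim=-1) @ v
    return out.transpose(1, 2).reshape(b, s, n_heads * head_dim)

b, s, n_heads, head_dim = 2, 16, 8, 64
q = torch.randn(b, s, n_heads * head_dim)
for n_kv in (8, 4, 1):  # 8 -> MHA, 4 -> GQA (groups of 2), 1 -> MQA
    k = torch.randn(b, s, n_kv * head_dim)
    v = torch.randn(b, s, n_kv * head_dim)
    print(n_kv, grouped_attention(q, k, v, n_heads, n_kv).shape)

A cache-aware implementation would avoid materializing the replicated heads, which is exactly the point of the indexing trick discussed next.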
Most mainstream training and inference frameworks and kernels already support MQA/GQA; FlashAttention, for example, supports both. For MQA and GQA, FlashAttention uses indexing rather than first replicating the KV heads into GPU memory and then computing: the KV head index is passed into the kernel, the memory address is computed from it, and K/V are read directly from memory.
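The "index instead of copy" idea can be mimicked in plain PyTorch (my own sketch, unrelated to FlashAttention's actual CUDA kernel): an expanded broadcast view exposes the replicated-head layout without allocating any new KV storage, whereas repeat_interleave materializes real copies.

import torch

b, s, n_heads, n_kv_heads, head_dim = 1, 1024, 32, 4, 128
k = torch.randn(b, n_kv_heads, s, head_dim)
group = n_heads // n_kv_heads

# Materializing copies multiplies KV memory by the group size.
k_copied = k.repeat_interleave(group, dim=1)                 # (1, 32, 1024, 128), new storage

# A broadcast view exposes the same logical layout without allocating anything new.
k_view = k.unsqueeze(2).expand(b, n_kv_heads, group, s, head_dim)

print(k_copied.numel() * k_copied.element_size(), "bytes materialized vs", k.numel() * k.element_size())
print(k_view.data_ptr() == k.data_ptr())                     # True: the view still points at the original KV storage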
MLA (Multi-head Latent Attention)
DeepSeek uses MLA, which is a further improvement on GQA. To be updated.
YOCO
YOCO is a decoder-decoder architecture, very close to decoder-only; it is called decoder-decoder because its two decoders do not mean quite the same thing. YOCO consists of two parts: a self-decoder, which is the same as an ordinary decoder Transformer, and a cross-decoder. The self-decoder produces the global KV cache, which is then consumed directly by the cross-decoder. That is also why the second half is called a cross-decoder: it runs cross-attention against the KV cache produced by the self-decoder and does not produce any KV cache of its own. A toy sketch of this idea follows.
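Below is my own toy sketch of the inter-layer sharing idea (it is not the YOCO implementation; every module name and size here is made up): the self-decoder half runs ordinary self-attention, a single global K/V is derived from its output, and every cross-decoder layer attends to that one shared cache instead of keeping its own.

import torch
from torch import nn

d_model, nhead, n_self, n_cross = 64, 4, 2, 4

self_layers = nn.ModuleList(nn.MultiheadAttention(d_model, nhead, batch_first=True) for _ in range(n_self))
cross_layers = nn.ModuleList(nn.MultiheadAttention(d_model, nhead, batch_first=True) for _ in range(n_cross))
kv_proj = nn.Linear(d_model, 2 * d_model)  # derive the single global K/V from the self-decoder output

x = torch.randn(1, 12, d_model)

# Self-decoder: an ordinary self-attention stack (each layer would normally keep its own KV cache).
h = x
for attn in self_layers:
    h = h + attn(h, h, h, need_weights=False)[0]

# One global KV cache for the whole cross-decoder stack.
k, v = kv_proj(h).chunk(2, dim=-1)

# Cross-decoder: every layer reuses the same K/V, so no per-layer KV cache is stored here.
y = h
for attn in cross_layers:
    y = y + attn(y, k, v, need_weights=False)[0]

print(y.shape)  # torch.Size([1, 12, 64])

Since the cross-decoder layers store no KV of their own, the cache no longer grows with the number of those layers, which is the inter-layer sharing described in the note at the top.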