Multi Query Attention & Group Query Attention

Multi Query Attention (MQA) was proposed back in 2019 as an inference-time speedup, but it received little attention at the time; after all, a single 2080 was enough to run BERT-base. With the rise of LLMs, the gains MQA brings have been greatly amplified.

Idea

Multi Query Attention (MQA) differs from Multi Head Attention (MHA) by only a single word, and its idea is very simple, almost identical to MHA:

(Figure: MHA vs. MQA model structure)

In MHA, the Query, Key, and Value are each split into 8 heads and every head runs its own self-attention, whereas in MQA only the Query is split into 8 heads and all of them share a single set of Key and Value:

MHA: Q, K, V: (512, 768)    # (seq_len, hidden_dim)
     split into 8 heads:
       Q: (8, 512, 96)
       K, V: (8, 512, 96)

MQA: Q: (512, 768)
     K: (512, 96)
     V: (512, 96)
     split Q into 8 heads:
       Q: (8, 512, 96)
       K, V stay: (512, 96)
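
To make these shapes concrete, below is a minimal sketch of MQA-style attention, assuming the 8-head / 96-dim split above (batch dimension omitted): the single shared K and V are simply broadcast against all 8 query heads.

```python
import torch

seq_len, hidden_dim, num_heads = 512, 768, 8
head_dim = hidden_dim // num_heads               # 96

q = torch.randn(num_heads, seq_len, head_dim)    # (8, 512, 96): 8 query heads
k = torch.randn(seq_len, head_dim)               # (512, 96): one shared key head
v = torch.randn(seq_len, head_dim)               # (512, 96): one shared value head

# (8, 512, 96) @ (96, 512) broadcasts the shared K over all heads -> (8, 512, 512)
scores = q @ k.transpose(-2, -1) / head_dim ** 0.5
attn = torch.softmax(scores, dim=-1)

out = attn @ v                                   # (8, 512, 512) @ (512, 96) -> (8, 512, 96)
print(out.shape)                                 # torch.Size([8, 512, 96])
```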

Code implementation

  • MHA
...
self.Wqkv = nn.Linear(  # a single fused projection produces Q, K and V together
    d_model,
    d_model * 3,
    device=device,
)
...

d_model * 3 is split into three 768-dimensional parts, one each for Q, K, and V (a sketch showing both splits follows this list).

  • MQA
...
self.Wqkv = nn.Linear(  # Q keeps d_model dims; K and V each get only one head_dim
    d_model,
    d_model + 2 * self.head_dim,
    device=device,
)
...

d_model + 2 * self.head_dim is split into one 768-dimensional part for Q plus two 96-dimensional parts for the shared K and V.
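
As a rough, self-contained sketch of how these two fused projections get split downstream (sizes and variable names here are illustrative assumptions, not the exact code excerpted above):

```python
import torch
import torch.nn as nn

d_model, n_heads, seq_len = 768, 8, 512
head_dim = d_model // n_heads            # 96
x = torch.randn(1, seq_len, d_model)

# MHA: fused projection to 3 * d_model, chunked into Q, K, V of 768 dims each
mha_Wqkv = nn.Linear(d_model, d_model * 3)
q, k, v = mha_Wqkv(x).chunk(3, dim=-1)
print(q.shape, k.shape, v.shape)         # (1, 512, 768) each

# MQA: fused projection to d_model + 2 * head_dim,
# split into a 768-dim Q plus a single 96-dim K and a single 96-dim V
mqa_Wqkv = nn.Linear(d_model, d_model + 2 * head_dim)
q, k, v = mqa_Wqkv(x).split([d_model, head_dim, head_dim], dim=-1)
print(q.shape, k.shape, v.shape)         # (1, 512, 768), (1, 512, 96), (1, 512, 96)
```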

As the shapes show, the number of projection parameters is greatly reduced.
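
A quick back-of-the-envelope check of the QKV projection weights (biases ignored), using the same assumed sizes as above:

```python
d_model, n_heads = 768, 8
head_dim = d_model // n_heads                          # 96

mha_qkv_weights = d_model * (3 * d_model)              # 1,769,472
mqa_qkv_weights = d_model * (d_model + 2 * head_dim)   # 737,280
print(mha_qkv_weights / mqa_qkv_weights)               # ~2.4x fewer QKV projection weights
```

In practice, an even bigger win at inference time is usually the K/V cache: with a single shared K and V head per layer, far less memory has to be stored and read at every decoding step.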

Experimental results

Quality metrics drop slightly, but the inference speedup is substantial.

(Figure: experimental results)

Group Query Attention

Q is split into 8 heads while K and V are each split into only 4 heads; each K/V head is then shared by its corresponding group of 2 query heads when attention is computed, as shown in the sketch below.
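
A minimal sketch of this grouped computation, using the same illustrative sizes (512 tokens, 96-dim heads): each of the 4 K/V heads is repeated so that it lines up with its group of 2 query heads.

```python
import torch

seq_len, head_dim = 512, 96
n_q_heads, n_kv_heads = 8, 4
group_size = n_q_heads // n_kv_heads             # 2 query heads share one K/V head

q = torch.randn(n_q_heads, seq_len, head_dim)    # (8, 512, 96)
k = torch.randn(n_kv_heads, seq_len, head_dim)   # (4, 512, 96)
v = torch.randn(n_kv_heads, seq_len, head_dim)   # (4, 512, 96)

# repeat each K/V head group_size times so shapes match the query heads: (4, ...) -> (8, ...)
k_rep = k.repeat_interleave(group_size, dim=0)
v_rep = v.repeat_interleave(group_size, dim=0)

scores = q @ k_rep.transpose(-2, -1) / head_dim ** 0.5   # (8, 512, 512)
attn = torch.softmax(scores, dim=-1)
out = attn @ v_rep                                       # (8, 512, 96)
print(out.shape)
```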


Reference

Here is a simple implementation of multi-head attention in PyTorch:

```python
import torch
import torch.nn as nn

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super(MultiHeadAttention, self).__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        self.query = nn.Linear(d_model, d_model)
        self.key = nn.Linear(d_model, d_model)
        self.value = nn.Linear(d_model, d_model)
        self.fc = nn.Linear(d_model, d_model)

    def split_heads(self, x, batch_size):
        x = x.view(batch_size, -1, self.num_heads, self.head_dim)
        x = x.permute(0, 2, 1, 3)
        return x

    def forward(self, query, key, value, mask=None):
        batch_size = query.size(0)

        # linear transformations
        query = self.query(query)
        key = self.key(key)
        value = self.value(value)

        # split into multiple heads
        query = self.split_heads(query, batch_size)
        key = self.split_heads(key, batch_size)
        value = self.split_heads(value, batch_size)

        # dot product attention
        scores = torch.matmul(query, key.transpose(-2, -1)) / torch.sqrt(torch.tensor(self.head_dim).float())
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float("-inf"))
        attention = torch.softmax(scores, dim=-1)
        x = torch.matmul(attention, value)

        # concatenate attention heads
        x = x.permute(0, 2, 1, 3).contiguous()
        x = x.view(batch_size, -1, self.d_model)

        # final linear transformation
        x = self.fc(x)
        return x
```

This implementation takes as input a `d_model` dimension tensor and splits it into `num_heads` attention heads. The `query`, `key`, and `value` matrices are linearly transformed and split into heads as well. Then, the dot product attention is calculated and the attention heads are concatenated and linearly transformed again. To use this module in your Transformer, you can simply call it like this:

```python
attn = MultiHeadAttention(d_model=512, num_heads=8)
output = attn(query, key, value)
```