MQA(Multi-Query Attention)详解

论文名称:Fast Transformer Decoding: One Write-Head is All You Need

论文地址:https://arxiv.org/abs/1911.02150v1

        MQA(Multi-Query Attention)是Google团队在2019年提出的,是MHA (Multi-head Attention,多头注意力机制)的一种变体,也是用于自回归解码的一种注意力机制。

        传统的MHA是将输入划分为多个Head,并为每个Head独立计算注意力。在MHA中的,Q、K、V会根据每个head做不同的转换(模拟:每个Head都有自己的感知域/parameter sets,可以独立学习输入中的不同特性)。这在Head数量较多时候可能会存在计算密集的问题。

        而与MHA 不同的是,MQA 让所有的Head之间共享同样的一份 K 和 V 矩阵(意味K和V的计算唯一),只让 Q 保留了原始多头的性质(每个Head存在不同的转换),从而大大减少 K 和 V 矩阵的参数量以及KV Cache的显存占用,以此来达到提升推理速度,但是会带来精度上的损失。技术被大量应用于大预言模型,如ChatGLM2。

从代码角度来看,形式如下:

K_shared = WK * K
V_shared = WV * V

for i in range(num_heads):
    Qi = WQi * Q
    ...
    ...

下面一段代码来自于下面这个链接的作者的实现chatGLM2中的Multi Query Attention_multi-query attention-CSDN博客

源码请看huggingface的transformers包中的bertselfattention源码实现。

    class MultiQuerySelfAttention(nn.Module):
    def __init__(self, num_attention_heads, hidden_size):
        super().__init__()
        self.num_attention_heads = num_attention_heads
        self.attention_head_size = int(hidden_size / num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size
 
        self.query = nn.Linear(hidden_size, self.all_head_size)
        self.key = nn.Linear(hidden_size, self.attention_head_size)
        self.value = nn.Linear(hidden_size, self.attention_head_size)
 
        self.dropout = nn.Dropout(0.1)
 
    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(new_x_shape)
        return x.permute(0, 2, 1, 3)
 
    def forward(self,hidden_states):
        # hidden_states (B, L, D)
        mixed_query_layer = self.query(hidden_states)
        # query_layer  (B, h, L, d)
        # 在此处,将query划分为多头[batch_size, head_num, 序列长度, embedding长度]
        query_layer = self.transpose_for_scores(mixed_query_layer)
 
        # 每个key、value head参数都是共享的,只计算一次
        key = self.key(hidden_states)
        #key_layer  (B, 1, L, d)
        key_layer = key.unsqueeze(1)
        value = self.value(hidden_states)
        # value_layer  (B, 1, L, d)
        value_layer = value.unsqueeze(1)
 
        # key_layer  (B, 1, d, L)
        key_layer = key_layer.transpose(-1, -2)
        #广播算法 (B, h, L, d) * (B, 1, d, L) => (B, h, L, d) * (B, h, d, L) = (B, h, L, L)
        attention_scores = torch.matmul(query_layer, key_layer)
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        attention_probs = nn.functional.softmax(attention_scores, dim=-1)
        attention_probs = self.dropout(attention_probs)
        #广播算法 (B, h, L, L) * (B, 1, L, d) =>(B, h, L, L) * (B, h, L, d)= (B, h, L, d)
        context_layer = torch.matmul(attention_probs, value_layer)
        #(B, h, L, d) => (B, L, h, d)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        # (B,L, h*d) => (B,L,D)
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        # (B,L, h*d) => (B,L,D)
        context_layer = context_layer.view(new_context_layer_shape)

        return context_layer

稍微补充一下:原论文中的MQA伪代码如下,和自注意力的MQA实现有些区别,个人猜测如下

        这里简单理解下,一般情况下我们讲的都是自注意力XXX,比如自注意力MHA,这时Q、K、V都来自于输入X;但是,论文中讲述的应该是纯粹的MHA和MQA,此时构成Q和K的输入就不同。(猜想来自于传统注意力机制,该机制多应用于seq-seq任务)

  • 3
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值