An Introduction to MHA, MQA, and GQA Attention, with Code Implementations

1. Summary

  • MHA (Multi-Head Attention): the standard multi-head attention mechanism, with h separate Query, Key, and Value matrices; every head has its own key-value pair.
  • MQA (Multi-Query Attention): a variant of multi-head attention, also used for autoregressive decoding, in which there is only a single set of keys and values. Unlike MHA, MQA shares one Key matrix and one Value matrix across all heads, while each head keeps its own Query parameters, which greatly reduces the number of Key and Value parameters.
  • GQA (Grouped-Query Attention): the attention heads are grouped. The query heads are divided into G groups, and each group shares a single Key and Value matrix. GQA-G denotes grouped-query attention with G groups: GQA-1 has a single group, hence a single Key and Value, and is equivalent to MQA; GQA-H has as many groups as heads and is equivalent to MHA.

GQA therefore sits between MHA and MQA: it gives up little of MHA's quality while keeping most of MQA's inference speed-up. Instead of all query heads sharing one set of keys and values, a fixed number of query heads share each KV head (for example, two query heads per KV head). The sketch below compares the K/V projection shapes of the three variants.
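
To make the difference concrete, here is a minimal sketch (hidden_size = 512, num_heads = 8, and a group count of 4 are illustrative values, not taken from the text above) showing that the three variants differ only in the output width of the K/V projections and hence in the number of KV heads:

import torch.nn as nn

hidden_size, num_heads, group_num = 512, 8, 4   # illustrative values
head_dim = hidden_size // num_heads             # 64

# MHA: the K/V projections produce num_heads KV heads (one per query head)
k_mha = nn.Linear(hidden_size, num_heads * head_dim)   # 512 -> 512

# MQA: the K/V projections produce a single shared KV head
k_mqa = nn.Linear(hidden_size, 1 * head_dim)           # 512 -> 64

# GQA: the K/V projections produce group_num KV heads, shared within each group
k_gqa = nn.Linear(hidden_size, group_num * head_dim)   # 512 -> 256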

2. Code Implementation

2.1 MHA

Multi-head attention is the core component of the Transformer model. In its design, "multi-head" means the mechanism does not compute a single set of attention weights; it computes several sets in parallel, each capturing different information about the input from a different "perspective".

  1. For every element of the input sequence, compute q, k, v by multiplying the input word vector with three weight matrices:
     $\begin{aligned} q &= x W_{q} \\ k &= x W_{k} \\ v &= x W_{v} \end{aligned}$
     where $x$ is the input word vector and $W_q$, $W_k$, $W_v$ are the weight matrices for q, k, and v.
  2. Compute the attention score between q and k: $\operatorname{score}(q, k)=\frac{q \cdot k^{T}}{\sqrt{d_{k}}}$, where $d_k$ is the dimension of k.
  3. Apply softmax to obtain the attention weights: $\operatorname{Attention}(q, K)=\operatorname{softmax}(\operatorname{score}(q, k))$.
  4. Use the attention weights and v to compute the output: $Output = \operatorname{Attention}(q, K) \cdot V$.
  5. Concatenate the outputs of all heads and multiply by $W_O$ to obtain the final output: $MultiHeadOutput = \operatorname{Concat}(Output^{1}, Output^{2}, \ldots, Output^{H})\, W_{O}$.
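
The five steps can be traced end to end in a few lines of PyTorch. The snippet below is a toy single-head illustration with arbitrarily chosen sizes, not part of the original walkthrough:

import torch

seq_len, d_model, d_k = 4, 8, 8                 # illustrative sizes
x = torch.randn(seq_len, d_model)               # input word vectors

W_q = torch.randn(d_model, d_k)                 # weight matrices for q, k, v
W_k = torch.randn(d_model, d_k)
W_v = torch.randn(d_model, d_k)

q, k, v = x @ W_q, x @ W_k, x @ W_v             # step 1: projections
scores = q @ k.T / (d_k ** 0.5)                 # step 2: scaled dot-product scores
weights = torch.softmax(scores, dim=-1)         # step 3: attention weights
output = weights @ v                            # step 4: weighted sum of the values
print(output.shape)                             # torch.Size([4, 8])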

Code implementation

import torch
from torch import nn

class MultiHeadAttention(torch.nn.Module):  # inherits from PyTorch's nn.Module
    def __init__(self, hidden_size, num_heads):  # takes the hidden size and the number of attention heads
        super(MultiHeadAttention, self).__init__()  # call the parent constructor
        self.num_heads = num_heads  # store the number of attention heads
        self.head_dim = hidden_size // num_heads  # dimension of each attention head
        
        # Q, K, V projection matrices that map the input into different subspaces
        self.q_linear = nn.Linear(hidden_size, hidden_size)  # linear layer producing Q
        self.k_linear = nn.Linear(hidden_size, hidden_size)  # linear layer producing K
        self.v_linear = nn.Linear(hidden_size, hidden_size)  # linear layer producing V
        
        # Output projection that maps the concatenated heads back to the original dimension
        self.o_linear = nn.Linear(hidden_size, hidden_size)
        
    def forward(self, hidden_state, attention_mask=None):  # forward pass
        batch_size = hidden_state.size()[0]  # get the batch size
        
        query = self.q_linear(hidden_state)  # compute the queries (Q)
        key = self.k_linear(hidden_state)  # compute the keys (K)
        value = self.v_linear(hidden_state)  # compute the values (V)
        
        query = self.split_head(query)  # split into multiple heads
        key = self.split_head(key)  # split into multiple heads
        value = self.split_head(value)  # split into multiple heads
        
        # Compute attention scores with scaled dot-product attention
        attention_scores = torch.matmul(query, key.transpose(-1, -2)) / (self.head_dim ** 0.5)
        
        if attention_mask is not None:  # if an attention mask is provided
            attention_scores += attention_mask * -1e9  # add a large negative value at masked positions
        
        # Normalize the scores with softmax to obtain the attention weights
        attention_probs = torch.softmax(attention_scores, dim=-1)
        
        output = torch.matmul(attention_probs, value)  # attention-weighted sum of the values
        
        # Move the head dimension back and concatenate the heads
        output = output.transpose(1, 2).contiguous().view(batch_size, -1, self.head_dim * self.num_heads)
        
        output = self.o_linear(output)  # final output projection
        
        return output  # return the result

    def split_head(self, x):  # split into multiple heads
        batch_size = x.size()[0]  # get the batch size
        return x.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)  # reshape to (batch, num_heads, seq_len, head_dim)
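
A quick usage sketch (batch size, sequence length, and model sizes are illustrative): a random batch passed through the module above comes back with the same shape as the input.

import torch

batch_size, seq_len, hidden_size, num_heads = 2, 10, 512, 8   # illustrative sizes
mha = MultiHeadAttention(hidden_size, num_heads)
hidden_state = torch.randn(batch_size, seq_len, hidden_size)
out = mha(hidden_state)
print(out.shape)   # torch.Size([2, 10, 512])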

    
    
        
        

2.2 MQA

Intuitively, when computing multi-query attention the queries are still split into heads, exactly as in multi-head attention, while the keys and values have only a single head.

Normally, when the multi-head attention scores are computed, the query and key have matching shapes, so the matrix multiplication can be carried out directly. In multi-query attention (MQA), however, the query has shape [batch_size, num_heads, seq_len, head_dim] while the key and value have shape [batch_size, 1, seq_len, head_dim]. The multiplication therefore cannot be done head by head; instead it relies on torch's broadcasting in the matrix multiply, as the sketch below shows.
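
A small shape check (sizes chosen for illustration) of how torch.matmul broadcasts the single KV head across all query heads:

import torch

batch_size, num_heads, seq_len, head_dim = 2, 8, 10, 64          # illustrative sizes
query = torch.randn(batch_size, num_heads, seq_len, head_dim)
key   = torch.randn(batch_size, 1, seq_len, head_dim)            # single shared KV head
value = torch.randn(batch_size, 1, seq_len, head_dim)

scores = torch.matmul(query, key.transpose(-1, -2))              # broadcast over the head dimension
print(scores.shape)                                              # torch.Size([2, 8, 10, 10])
output = torch.matmul(torch.softmax(scores / head_dim ** 0.5, dim=-1), value)
print(output.shape)                                              # torch.Size([2, 8, 10, 64])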

# Import the torch library
import torch
# Import the neural network module nn from torch
from torch import nn

# Multi-query attention module, inheriting from torch.nn.Module
class MultiQueryAttention(torch.nn.Module):
    # hidden_size is the hidden dimension, num_heads is the number of attention heads
    def __init__(self, hidden_size, num_heads):
        # Call the parent class constructor
        super(MultiQueryAttention, self).__init__()
        # Store the number of attention heads
        self.num_heads = num_heads
        # Dimension of each attention head (hidden_size is assumed to be divisible by num_heads)
        self.head_dim = hidden_size // num_heads
        
        ## Q, K, V projection layers
        # Linear layer that produces the queries; input and output dimensions are both hidden_size
        self.q_linear = nn.Linear(hidden_size, hidden_size)
        # Linear layer that produces the keys; the output dimension is head_dim
        self.k_linear = nn.Linear(hidden_size, self.head_dim)  ###
        # Linear layer that produces the values; the output dimension is head_dim
        self.v_linear = nn.Linear(hidden_size, self.head_dim)  ###
        
        ## Output projection layer that merges the outputs of all heads
        self.o_linear = nn.Linear(hidden_size, hidden_size)
        
    # Forward pass; hidden_state is the input hidden states, attention_mask is an optional mask
    def forward(self, hidden_state, attention_mask=None):
        # Batch size, taken from the first dimension of hidden_state
        batch_size = hidden_state.size()[0]
        
        # Queries via the q_linear projection
        query = self.q_linear(hidden_state)
        # Keys via the k_linear projection
        key = self.k_linear(hidden_state)
        # Values via the v_linear projection
        value = self.v_linear(hidden_state)
        
        # Split the queries into num_heads attention heads
        query = self.split_head(query)
        # The keys keep a single head (head_num=1)
        key = self.split_head(key, 1)
        # The values keep a single head (head_num=1)
        value = self.split_head(value, 1)
        
        ## Compute the attention scores
        # Dot product of queries and transposed keys, scaled by the square root of head_dim;
        # the single KV head is broadcast across all query heads
        attention_scores = torch.matmul(query, key.transpose(-1, -2)) / (self.head_dim ** 0.5)
        
        # If an attention mask is provided, add a large negative value at the masked positions
        if attention_mask is not None:
            attention_scores += attention_mask * -1e9
        
        ## Normalize the attention scores
        # Softmax along the last dimension gives the attention probabilities
        attention_probs = torch.softmax(attention_scores, dim=-1)
        
        # Attention-weighted sum of the values (the value tensor is broadcast as well)
        output = torch.matmul(attention_probs, value)
        
        # Move the head dimension next to head_dim, make the tensor contiguous,
        # then reshape to (batch_size, seq_len, head_dim * num_heads)
        output = output.transpose(1, 2).contiguous().view(batch_size, -1, self.head_dim * self.num_heads)
        
        # Final linear projection of the merged output
        output = self.o_linear(output)
        
        # Return the attention output
        return output
        
    # Helper that splits the input tensor into attention heads
    def split_head(self, x, head_num=None):
        # Batch size, taken from the first dimension of x
        batch_size = x.size()[0]
        
        # If head_num is not given, split using the num_heads defined in __init__
        if head_num is None:
            # Reshape to (batch_size, seq_len, num_heads, head_dim) and swap dimensions 1 and 2
            return x.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        else:
            # Otherwise reshape to (batch_size, seq_len, head_num, head_dim) and swap dimensions 1 and 2
            return x.view(batch_size, -1, head_num, self.head_dim).transpose(1, 2)

    
    

Compared with multi-head attention, multi-query attention differs in the output dimensions of the W_k and W_v projections, and both the attention scores and the final output are computed through broadcasting; everything else is identical to multi-head attention.
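
One way to see the saving is to compare the parameter counts of the two modules defined above; the numbers below assume hidden_size = 512 and num_heads = 8, chosen purely for illustration.

hidden_size, num_heads = 512, 8          # illustrative sizes
mha = MultiHeadAttention(hidden_size, num_heads)
mqa = MultiQueryAttention(hidden_size, num_heads)

def num_params(m):
    return sum(p.numel() for p in m.parameters())

print(num_params(mha))   # four full 512 -> 512 projections (plus biases)
print(num_params(mqa))   # K and V shrink to 512 -> 64 each, so far fewer KV parameters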

2.3 GQA

GQA sets the number of key/value heads, which is 1 in MQA, to some number of groups that evenly divides the original number of attention heads.

Different models implement GQA in slightly different ways, but the overall idea is as described above. Note that the number of attention heads must be evenly divisible by the chosen number of groups.

## Grouped-query attention
import torch
from torch import nn

# GroupQueryAttention class, inheriting from nn.Module
class GroupQueryAttention(torch.nn.Module):
    def __init__(self, hidden_size, num_heads, group_num):
        super(GroupQueryAttention, self).__init__()
        
        # Number of heads, dimension per head, and number of groups
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        self.group_num = group_num
        
        # Q, K, V projection matrices
        self.q_linear = nn.Linear(hidden_size, hidden_size)  # query projection Q
        self.k_linear = nn.Linear(hidden_size, self.group_num * self.head_dim)  # key projection K (one KV head per group)
        self.v_linear = nn.Linear(hidden_size, self.group_num * self.head_dim)  # value projection V (one KV head per group)
        
        # Output projection layer
        self.o_linear = nn.Linear(hidden_size, hidden_size)

    # Forward pass
    def forward(self, hidden_state, attention_mask=None):
        batch_size = hidden_state.size()[0]  # get the batch size
        
        # Compute Q, K, V
        query = self.q_linear(hidden_state)  # query vectors Q
        key = self.k_linear(hidden_state)  # key vectors K
        value = self.v_linear(hidden_state)  # value vectors V
        
        # Split Q, K, V into heads
        query = self.split_head(query)
        key = self.split_head(key, self.group_num)  # split the keys by group
        value = self.split_head(value, self.group_num)  # split the values by group
        
        # Compute the attention scores
        attention_scores = torch.matmul(query, key.transpose(-1, -2)) / (self.head_dim ** 0.5)
        
        # If an attention_mask is provided, suppress the masked positions
        if attention_mask is not None:
            attention_scores += attention_mask * -1e9
        
        # Softmax over the scores gives the attention weights
        attention_probs = torch.softmax(attention_scores, dim=-1)
        
        # Attention-weighted sum of the values
        output = torch.matmul(attention_probs, value)
        
        # Move the head dimension back and restore the original shape
        output = output.transpose(1, 2).contiguous().view(batch_size, -1, self.head_dim * self.num_heads)
        
        # Map through the output projection
        output = self.o_linear(output)
        
        return output

    # Helper that splits a tensor into heads
    def split_head(self, x, group_num=None):
        
        # Batch size and sequence length
        batch_size, seq_len = x.size()[:2]
        
        # If group_num is not given, split into num_heads query heads
        if group_num is None:
            return x.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        else:
            # Otherwise split into group_num KV heads
            x = x.view(batch_size, -1, group_num, self.head_dim).transpose(1, 2)
            # Repeat each KV head num_heads // group_num times so that every query head has a
            # matching KV head, then merge back to (batch_size, num_heads, seq_len, head_dim)
            x = x[:, :, None, :, :].expand(batch_size, group_num, self.num_heads // group_num, seq_len, self.head_dim).reshape(batch_size, self.num_heads // group_num * group_num, seq_len, self.head_dim)
            return x
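
A usage sketch with illustrative sizes: setting group_num to 1 makes the module behave like MQA, while setting it to num_heads makes it behave like MHA, matching the GQA-1 / GQA-H description in the summary.

batch_size, seq_len, hidden_size, num_heads = 2, 10, 512, 8   # illustrative sizes
hidden_state = torch.randn(batch_size, seq_len, hidden_size)

gqa = GroupQueryAttention(hidden_size, num_heads, group_num=4)   # 2 query heads per KV head
print(gqa(hidden_state).shape)                                   # torch.Size([2, 10, 512])

gqa_1 = GroupQueryAttention(hidden_size, num_heads, group_num=1)          # equivalent to MQA
gqa_h = GroupQueryAttention(hidden_size, num_heads, group_num=num_heads)  # equivalent to MHA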
