MHA、MQA、GQA注意力的介绍和代码实现
1.总结
- 在 MHA(Multi Head Attention) 中,每个头有自己单独的 key-value 对;标准的多头注意力机制,h个Query、Key 和 Value 矩阵。
- 在 MQA(Multi Query Attention) 中只会有一组 key-value 对;多查询注意力的一种变体,也是用于自回归解码的一种注意力机制。与MHA不同的是,MQA 让所有的头之间共享同一份 Key 和 Value 矩阵,每个头只单独保留了一份 Query 参数,从而大大减少 Key 和 Value 矩阵的参数量。
- 在 GQA(Grouped Query Attention)中,会对 attention 进行分组操作,query 被分为 N 组,每个组共享一个 Key 和 Value 矩阵GQA将查询头分成G组,每个组共享一个Key 和 Value 矩阵。GQA-G是指具有G组的grouped-query attention。GQA-1具有单个组,因此具有单个Key 和 Value,等效于MQA。而GQA-H具有与头数相等的组,等效于MHA。
GQA-N 是指具有 N 组的 Grouped Query Attention。GQA-1具有单个组,因此具有单个Key 和 Value,等效于MQA。而GQA-H具有与头数相等的组,等效于MHA。
GQA介于MHA和MQA之间。GQA 综合 MHA 和 MQA ,既不损失太多性能,又能利用 MQA 的推理加速。不是所有 Q 头共享一组 KV,而是分组一定头数 Q 共享一组 KV,比如上图中就是两组 Q 共享一组 KV。
2.代码实现
2.1 MHA
多头注意力机制是Transformer模型中的核心组件。在其设计中,"多头"意味着该机制并不只计算一种注意力权重,而是并行计算多种权重,每种权重都从不同的“视角”捕获输入的不同信息。
- 为输入序列中的每个元素计算q, k, v,这是通过将输入此向量与三个权重矩阵相乘实现的:
q = x W q k = x W k v = x W v \begin{aligned} q & =x W_{q} \\ k & =x W_{k} \\ v & =x W_{v}\end{aligned} qkv=xWq=xWk=xWv
其中, x x x是输入词向量, W q W_q Wq, W k W_k Wk和 W v W_v Wv是q, k, v的权重矩阵 - 计算q, k 注意力得分: score ( q , k ) = q ⋅ k T d k \operatorname{score}(q, k)=\frac{q \cdot k^{T}}{\sqrt{d_{k}}} score(q,k)=dkq⋅kT,其中, d k d_k dk是k的维度
- 使用softmax得到注意力权重: Attention ( q , K ) = softmax ( score ( q , k ) ) \operatorname{Attention}(q, K)=\operatorname{softmax}(\operatorname{score}(q, k)) Attention(q,K)=softmax(score(q,k))
- 使用注意力权重和v,计算输出: O u t p u t = Attention ( q , K ) ⋅ V Output =\operatorname{Attention}(q, K) \cdot V Output=Attention(q,K)⋅V
- 拼接多头输出,乘以 W O W_O WO,得到最终输出: M u l t i H e a d O u t p u t = C o n c a t ( O u t p u t 1 , O u t p u t 2 , … , O u t p u t H ) W O MultiHeadOutput = Concat \left(\right. Output ^{1}, Output ^{2}, \ldots, Output \left.^{H}\right) W_{O} MultiHeadOutput=Concat(Output1,Output2,…,OutputH)WO
代码实现
import torch
from torch import nn
class MutiHeadAttention(torch.nn.Module): # 继承 PyTorch 的 nn.Module
def __init__(self, hidden_size, num_heads): # 初始化函数,接收隐藏层维度和注意力头数
super(MutiHeadAttention, self).__init__() # 调用父类构造函数
self.num_heads = num_heads # 保存注意力头数
self.head_dim = hidden_size // num_heads # 计算每个注意力头的维度
# 初始化 Q、K、V 投影矩阵,将输入投影到不同的空间
self.q_linear = nn.Linear(hidden_size, hidden_size) # 线性变换层,映射 Q
self.k_linear = nn.Linear(hidden_size, hidden_size) # 线性变换层,映射 K
self.v_linear = nn.Linear(hidden_size, hidden_size) # 线性变换层,映射 V
# 输出线性变换层,将多头注意力结果拼接后映射回原始维度
self.o_linear = nn.Linear(hidden_size, hidden_size)
def forward(self, hidden_state, attention_mask=None): # 前向传播函数
batch_size = hidden_state.size()[0] # 获取 batch_size
query = self.q_linear(hidden_state) # 计算查询(Query)
key = self.k_linear(hidden_state) # 计算键(Key)
value = self.v_linear(hidden_state) # 计算值(Value)
query = self.split_head(query) # 拆分多头结构
key = self.split_head(key) # 拆分多头结构
value = self.split_head(value) # 拆分多头结构
# 计算注意力分数,使用缩放点积注意力机制
attention_scores = torch.matmul(query, key.transpose(-1, -2)) / torch.sqrt(torch.tensor(self.head_dim))
if attention_mask is not None: # 如果提供了注意力掩码
attention_scores += attention_mask * -1e-9 # 施加掩码,屏蔽无关部分
# 对注意力分数进行 softmax 归一化,得到注意力权重
attention_probs = torch.softmax(attention_scores, dim=-1)
output = torch.matmul(attention_probs, value) # 计算注意力加权后的输出
# 对注意力输出进行拼接
output = output.transpose(-1, -2).contiguous().view(batch_size, -1, self.head_dim * self.num_heads)
output = self.o_linear(output) # 通过输出线性层
return output # 返回最终输出
def split_head(self, x): # 拆分多头方法
batch_size = x.size()[0] # 获取 batch_size
return x.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) # 变换形状以适应多头注意力
2.2 MQA
上图最右侧,直观上就是在计算多头注意力的时候,query仍然进行分头,和多头注意力机制相同,而key和value只有一个头。
正常情况在计算多头注意力分数的时候,query、key的维度是相同的,所以可以直接进行矩阵乘法,但是在多查询注意力(MQA)中,query的维度为 [batch_size, num_heads, seq_len, head_dim]
,key和value的维度为 [batch_size, 1, seq_len, head_dim]
。这样就无法直接进行矩阵的乘法,为了完成这一乘法,可以采用torch的广播乘法
# 导入torch库
import torch
# 从torch库中导入神经网络模块nn
from torch import nn
# 定义多查询注意力模块,继承自torch.nn.Module
class MutiQueryAttention(torch.nn.Module):
# 初始化函数,hidden_size为隐藏层大小,num_heads为注意力头的数量
def __init__(self, hidden_size, num_heads):
# 调用父类的初始化方法
super(MutiQueryAttention, self).__init__()
# 保存注意力头的数量
self.num_heads = num_heads
# 计算每个注意力头的维度(假设hidden_size可以被num_heads整除)
self.head_dim = hidden_size // num_heads
## 初始化Q、K、V的线性投影层
# 定义用于生成查询向量的全连接层,输入和输出的维度均为hidden_size
self.q_linear = nn.Linear(hidden_size, hidden_size)
# 定义用于生成键向量的全连接层,输出维度为head_dim
self.k_linear = nn.Linear(hidden_size, self.head_dim) ###
# 定义用于生成值向量的全连接层,输出维度为head_dim
self.v_linear = nn.Linear(hidden_size, self.head_dim) ###
## 初始化输出全连接层,用于整合各注意力头的输出
self.o_linear = nn.Linear(hidden_size, hidden_size)
# 定义前向传播函数,hidden_state为输入的隐藏状态,attention_mask为可选的注意力掩码
def forward(self, hidden_state, attention_mask=None):
# 获取批次大小,从hidden_state的第一个维度获得
batch_size = hidden_state.size()[0]
# 通过q_linear全连接层生成查询向量
query = self.q_linear(hidden_state)
# 通过k_linear全连接层生成键向量
key = self.k_linear(hidden_state)
# 通过v_linear全连接层生成值向量
value = self.v_linear(hidden_state)
# 将查询向量拆分为多个注意力头
query = self.split_head(query)
# 将键向量拆分为多个注意力头,传入head_num参数为1
key = self.split_head(key, 1)
# 将值向量拆分为多个注意力头,传入head_num参数为1
value = self.split_head(value, 1)
## 计算注意力分数
# 计算查询和键向量的点积,并对最后一个维度进行转置,再除以head_dim的平方根进行缩放
attention_scores = torch.matmul(query, key.transpose(-1, -2)) / torch.sqrt(torch.tensor(self.head_dim))
# 如果提供了注意力掩码,则将其应用于注意力分数(乘以一个很小的负数因子)
if attention_mask != None:
attention_scores += attention_mask * -1e-9
## 对注意力分数进行归一化
# 对注意力分数沿着最后一个维度使用softmax函数归一化,得到注意力概率
attention_probs = torch.softmax(attention_scores, dim=-1)
# 用归一化的注意力概率对值向量进行加权求和,得到注意力输出
output = torch.matmul(attention_probs, value)
# 将输出张量的最后两个维度进行转置,调用contiguous保证内存连续性,
# 再reshape为(batch_size, 序列长度, head_dim * num_heads)
output = output.transpose(-1, -2).contiguous().view(batch_size, -1, self.head_dim * self.num_heads)
# 将整合后的输出通过输出全连接层进行最后的线性变换
output = self.o_linear(output)
# 返回最终的注意力输出
return output
# 定义辅助函数split_head,用于将输入张量拆分成多个注意力头
def split_head(self, x, head_num=None):
# 获取批次大小,从x的第一个维度获得
batch_size = x.size()[0]
# 如果未指定head_num,则使用初始化时定义的num_heads进行拆分
if head_num == None:
# 将x重塑为 (batch_size, 序列长度, num_heads, head_dim) 并交换第1和第2个维度
return x.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
else:
# 如果指定了head_num,则将x重塑为 (batch_size, 序列长度, head_num, head_dim) 并交换第1和第2个维度
return x.view(batch_size, -1, head_num, self.head_dim).transpose(1, 2)
相比于多头注意力,多查询注意力在W_k和W_v的维度映射上有所不同,还有就是计算注意力分数采用的是广播机制,计算最后的output也是广播机制,其他的与多头注意力完全相同。
2.3 GQA
GQA将MAQ中的key、value的注意力头数设置为一个能够被原本的注意力头数整除的一个数字,也就是group数。
不同的模型使用GQA有着不同的实现方式,但是总体的思路就是这么实现的,注意,设置的组一定要能够被注意力头数整除。
## 分组注意力查询
import torch
from torch import nn
# 定义一个GroupQueryAttention类,继承自nn.Module
class GroupQueryAttention(torch.nn.Module):
def __init__(self, hidden_size, num_heads, group_num):
super(MutiQueryAttention, self).__init__()
# 设置头数、每个头的维度和组数
self.num_heads = num_heads
self.head_dim = hidden_size // num_heads
self.group_num = group_num
# 初始化Q、K、V投影矩阵
self.q_linear = nn.Linear(hidden_size, hidden_size) # 查询矩阵Q
self.k_linear = nn.Linear(hidden_size, self.group_num * self.head_dim) # 键矩阵K
self.v_linear = nn.Linear(hidden_size, self.group_num * self.head_dim) # 值矩阵V
# 输出的线性变换层
self.o_linear = nn.Linear(hidden_size, hidden_size)
# 定义前向传播函数
def forward(self, hidden_state, attention_mask=None):
batch_size = hidden_state.size()[0] # 获取批次大小
# 计算Q、K、V
query = self.q_linear(hidden_state) # 计算查询向量Q
key = self.k_linear(hidden_state) # 计算键向量K
value = self.v_linear(hidden_state) # 计算值向量V
# 将Q、K、V拆分成多个头
query = self.split_head(query)
key = self.split_head(key, self.group_num) # 按照组数拆分键
value = self.split_head(value, self.group_num) # 按照组数拆分值
# 计算注意力分数
attention_scores = torch.matmul(query, key.transpose(-1, -2)) / torch.sqrt(torch.tensor(self.head_dim))
# 如果提供了attention_mask,则对注意力分数做遮盖
if attention_mask != None:
attention_scores += attention_mask * -1e-9
# 对注意力分数进行softmax归一化,得到注意力权重
attention_probs = torch.softmax(attention_scores, dim=-1)
# 根据注意力权重加权求值
output = torch.matmul(attention_probs, value)
# 对输出进行维度转换,并恢复原始形状
output = output.transpose(-1, -2).contiguous().view(batch_size, -1, self.head_dim * self.num_heads)
# 通过输出线性层映射到最终的输出空间
output = self.o_linear(output)
return output
# 定义拆分头部的函数
def split_head(self, x, group_num=None):
# 获取批次大小和序列长度
batch_size, seq_len = x.size()[:2]
# 如果没有给定group_num,按照头数拆分
if group_num == None:
return x.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
else:
# 按照给定的组数拆分
x = x.view(batch_size, -1, group_num, self.head_dim).transpose(1, 2)
# 扩展x的维度并重新排列,以符合多头注意力的需求
x = x[:, :, None, :, :].expand(batch_size, group_num, self.num_heads // group_num, seq_len, self.head_dim).reshape(batch_size, self.num_heads // group_num * group_num, seq_len, self.head_dim)
return x