【AI学习】Transformer深入学习(二):从MHA、MQA、GQA到MLA

前面文章:
Transformer深入学习(一):Sinusoidal位置编码的精妙

一、MHA、MQA、GQA

为了降低KV cache,MQA、GQA作为MHA的变体,很容易理解。
多头注意力(MHA):
多头注意力是一种在Transformer架构中广泛使用的注意力机制,通过将查询、键和值分别投影到多个不同的空间上,然后并行计算这些空间上的注意力得分,从而获得更加丰富和细致的特征表示。

多查询注意力(MQA)
多查询注意力是MHA的一种变种,它通过共享单个key和value头来提升性能,但可能会导致质量下降和训练不稳定。MQA在保持速度的同时提高了模型的推理效率,但在某些情况下可能无法达到与MHA相同的效果。

分组查询注意力(GQA)
分组查询注意力是MQA和MHA之间的过渡方法,旨在同时保持MQA的速度和MHA的质量。GQA通过使用中间数量的键值头(大于一个,小于查询头的数量),实现了性能和速度的平衡。具体来说,GQA通过分组的方式减少了需要处理的头数,从而降低了内存需求和计算复杂度。

分组查询注意力(Grouped-Query Attention,简称GQA)是一种用于提高大模型推理可扩展性的机制。其具体实现机制如下:

1、基本概念:GQA是多头

### MHAGQAMLA 的区别及应用场合 #### 多头注意力机制(Multi-Head Attention, MHA) 多头注意力机制允许模型在不同的表示子空间中并行关注不同位置的信息。每个头独立操作,最终结果通过拼接各头的结果来获得更丰富的特征表达[^1]。 ```python import torch.nn as nn class MultiHeadAttention(nn.Module): def __init__(self, embed_size, num_heads): super(MultiHeadAttention, self).__init__() self.embed_size = embed_size self.num_heads = num_heads # 定义线性变换层 self.query_linear = nn.Linear(embed_size, embed_size) self.key_linear = nn.Linear(embed_size, embed_size) self.value_linear = nn.Linear(embed_size, embed_size) def forward(self, query, key, value): batch_size = query.size(0) # 对输入进行线性变换 Q = self.query_linear(query) K = self.key_linear(key) V = self.value_linear(value) # 将嵌入维度分割成多个头 Q = Q.view(batch_size, -1, self.num_heads, self.embed_size // self.num_heads).transpose(1, 2) K = K.view(batch_size, -1, self.num_heads, self.embed_size // self.num_heads).transpose(1, 2) V = V.view(batch_size, -1, self.num_heads, self.embed_size // self.num_heads).transpose(1, 2) # 计算注意力分数并加权求和 scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.embed_size ** 0.5) attention_weights = F.softmax(scores, dim=-1) output = torch.matmul(attention_weights, V) # 合并头部并将结果传递给下一个线性层 output = output.transpose(1, 2).contiguous().view(batch_size, -1, self.embed_size) return output ``` #### 组化查询注意力机制(Grouped Query Attention, GQA) 为了减少计算量,GQA引入了查询分组的概念,即某些查询可以共享相同的键和值矩阵。这减少了重复计算的数量,在大规模数据集上尤其有效[^3]。 ```python class GroupedQueryAttention(nn.Module): def __init__(self, embed_size, num_groups, heads_per_group): super(GroupedQueryAttention, self).__init__() self.embed_size = embed_size self.num_groups = num_groups self.heads_per_group = heads_per_group # 初始化参数... def forward(self, queries, keys, values): # 实现GQA的具体逻辑... pass ``` #### 压缩键值注意力机制(Compressed Key/Value Attention, MLAMLA进一步优化了资源利用效率,通过对键和值向量实施低秩近似压缩处理,从而显著降低了存储开销以及前向传播过程中的运算复杂度。 $$K_{\text{compressed}} = U_K \cdot S_K \cdot V_K^T$$ $$V_{\text{compressed}} = U_V \cdot S_V \cdot V_V^T$$ ```python from scipy.linalg import svd def compress_matrix(matrix, rank): u, s, vh = svd(matrix) compressed = np.dot(u[:, :rank], np.dot(np.diag(s[:rank]), vh[:rank, :])) return compressed class CompressedKeyValueAttention(nn.Module): def __init__(self, embed_size, compression_rank): super(CompressedKeyValueAttention, self).__init__() self.compress_key = lambda k: compress_matrix(k, compression_rank) self.compress_value = lambda v: compress_matrix(v, compression_rank) # 其他初始化... def forward(self, queries, keys, values): compressed_keys = self.compress_key(keys) compressed_values = self.compress_value(values) # 使用压缩后的keys和values继续执行标准的注意力机制... pass ```
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

bylander

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值