在Transformer模型中,注意力机制是其核心组成部分,它允许模型在处理序列数据时关注输入序列的不同部分。传统的多头注意力(MHA: Multi-head Attention)通过并行运行多个注意力头来捕捉输入序列中不同方面的关联。然而,MHA在计算和内存效率方面存在一定的局限性,尤其是在处理长序列时。DeepSeek V3中的多头潜在注意力(MLA: Multi-head Latent Attention)旨在解决这些问题,提供一种更高效的注意力机制。本文咱们谈谈MLA与MHA的特性,示例性实现代码,以及他们之间关键差异。
一、MHA: Multi-head Attention (多头注意力)
MHA通过将输入queries (Q)、keys (K) 和 values (V) 分别投影到多个不同的子空间(即不同的“头”),然后在每个子空间中独立计算注意力。最后,将所有头的输出拼接起来并进行线性变换,得到最终的注意力输出。MHA的计算过程可以概括为以下步骤:
- 线性投影:将Q、K、V分别通过不同的线性变换矩阵投影到不同的子空间。
- 缩放点积注意力:在每个子空间中计算缩放点积注意力。
- 拼接:将所有头的输出拼接起来。
- 线性变换:对拼接后的结果进行线性变换,得到最终的输出。
二、MLA: Multi-head Latent Attention (多头潜在注意力)
MLA的核心思想是使用低秩分解(LoRA)来近似Key和Value的投影,并使用旋转位置编码(RoPE)来编码位置信息。这使得MLA在参数效率和计算效率上优于MHA。MLA的主要设计目标是提高效率。MLA的计算过程可以概括为:
低秩分解(LoRA)应用于Key和Value的投影:
MLA使用低秩矩阵来近似Key和Value的投影矩阵。这意味着将一个大的投影矩阵分解为两个小矩阵的乘积。具体来说,它使用两个线性层wkv_a和wkv_b来代替一个大的Key/Value投影矩阵。wkv_a将输入投影到一个低维空间(kv_lora_rank),然后wkv_b将其投影回原始维度。这种方法显著减少了需要训练的参数数量,从而降低了内存占用和计算复杂度。
旋转位置编码(RoPE)应用于Query和Key:
MLA使用RoPE来为Query和Key添加位置信息。RoPE通过旋转Query和Key向量来实现,旋转的角度取决于它们在序列中的位置。这种方法不需要额外的参数,并且可以很好地泛化到不同的序列长度。
Query的LoRA(可选):
MLA还允许对Query使用LoRA,这与Key和Value的LoRA类似,可以进一步减少参数数量。
优化的注意力计算(吸收式实现 - “absorb” impl):
DeepSeek V3的MLA包含一种优化的注意力计算方式,称为“吸收式”实现。这种方法通过将部分线性变换“吸收”到注意力计算中,进一步提高了效率。具体来说,它将wkv_b的部分计算融入到注意力分数计算中,减少了后续的矩阵乘法操作。下面是MLA的示例实现。 `
三、MLA与MHA的关键差异
以下是使用PyTorch实现MHA和MLA的示例代码:
import torch
import torch.nn as nn
class MultiHeadAttention(nn.Module):
def __init__(self, d_model, num_heads):
super(MultiHeadAttention, self).__init__()
self.d_model = d_model
self.num_heads = num_heads
assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
self.d_k = d_model // num_heads
self.W_q = nn.Linear(d_model, d_model)
self.W_k = nn.Linear(d_model, d_model)
self.W_v = nn.Linear(d_model, d_model)
self.W_o = nn.Linear(d_model, d_model)
def scaled_dot_product_attention(self, q, k, v, mask=None):
attn = torch.matmul(q, k.transpose(-2, -1)) / torch.sqrt(torch.tensor(self.d_k, dtype=torch.float32))
if mask is not None:
attn = attn.masked_fill(mask == 0, float('-inf'))
attn = torch.softmax(attn, dim=-1)
output = torch.matmul(attn, v)
return output
def forward(self, q, k, v, mask=None):
batch_size = q.size(0)
q = self.W_q(q).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
k = self.W_k(k).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
v = self.W_v(v).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
attn_output = self.scaled_dot_product_attention(q, k, v, mask)
attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
output = self.W_o(attn_output)
return output
import torch
import torch.nn as nn
class Frequencies(nn.Module):
def __init__(self, max_seq_length,head_dim):
super().__init__()
assert head_dim%2==0, 'head_dim should be even'
# m values for different positions in sequence
m = torch.arange(0,max_seq_length)
# theta values for different index in token vector
theta = 1/(10000**2*torch.arange(0,head_dim//2)/head_dim)
#all possible combinations for m and theta
freq = torch.outer(m,theta)
#converting freq to polar
complex_freq = torch.polar(torch.ones_like(freq),freq)
self.register_buffer('complex_freq', complex_freq.unsqueeze(0).unsqueeze(2))
def forward(self):
return self.complex_freq
def rope(x,complex_freq):
b ,s, h, d = x.shape
x = x.view(b, s, h, -1, 2)
x = torch.view_as_complex(x)
x = x * complex_freq[:,s,:,:]
x = torch.view_as_real(x)
x = x.view(b,s,h,d)
return x
class MultiHeadLatentAttention(nn.Module):
def __init__(self, hidden_dim:int, heads:int,
v_head_dim:int, kv_rank:int, query_rank:int, rope_qk_dim:int):
super().__init__()
self.heads = heads
self.hidden_dim = hidden_dim
self.rope_qk_dim = rope_qk_dim
self.kv_rank = kv_rank
self.query_rank = query_rank
self.v_head_dim = v_head_dim
assert self.hidden_dim % self.heads == 0 ,"hidden_dim must be divisible by heads"
self.head_dim = self.hidden_dim // self.heads
self.qk_head_dim = self.head_dim + self.rope_qk_dim # head_dim for qk
# down and up projection matrix (query)
self.wq_d = nn.Linear(self.hidden_dim,self.query_rank,bias=False)
self.q_norm = nn.RMSNorm(self.query_rank)
self.wq_u = nn.Linear(self.query_rank, self.heads * self.qk_head_dim,bias=False)
# down and up projection matrix (key_value)
self.wkv_d = nn.Linear(self.hidden_dim, (self.kv_rank + self.rope_qk_dim) ,bias=False)
self.kv_norm = nn.RMSNorm(self.kv_rank)
self.wkv_u = nn.Linear(self.kv_rank, self.heads * (self.head_dim + self.v_head_dim) ,bias=False)
#output_linear_layer
self.wo = nn.Linear(self.heads * self.v_head_dim ,self.hidden_dim,bias=False)
self.scale = self.qk_head_dim ** -0.5
def forward(self, x: torch.Tensor, rope_freq: torch.Tensor, mask: torch.Tensor = None):
b,s,d = x.shape
# (b,s,d) -> (b,s, n_h * qk_dim)
q = self.wq_u(self.q_norm(self.wq_d(x)))
q = q.view(-1,s,self.heads,self.qk_head_dim)
#(b,s, n_h * qk_dim) -> (b,s,n_h,d_h) , (b,s,n_h, d_r)
q , q_rope = torch.split(q, [self.head_dim,self.rope_qk_dim],dim=-1)
q_rope = rope(q_rope,rope_freq)
q = torch.cat([q , q_rope],dim=-1)
kv_c = self.wkv_d(x)
#(b,s, n_h * qk_dim) -> (b,s,n_h * d_h) , (b,s,n_h * d_r)
kv_c, k_rope = torch.split(kv_c,[self.kv_rank,self.rope_qk_dim],dim=-1)
k_rope = rope(k_rope.unsqueeze(2),rope_freq)
k_rope = k_rope.expand(-1,-1,self.heads,-1)
kv = self.wkv_u(self.kv_norm(kv_c))
kv = kv.view(-1,s,self.heads,(self.head_dim + self.v_head_dim))
k ,v = torch.split(kv,[self.head_dim , self.v_head_dim],dim=-1)
# (b, s ,n_h, qk_dim)
k = torch.cat([k , k_rope], dim=-1)
# attention mechanism
q,k,v = q.transpose(1,2), k.transpose(1,2), v.transpose(1,2)
attention_scores = q @ k.transpose(2,3) * self.scale
# causal_masking
if mask is not None :
attention_scores = attention_scores.masked_fill(mask==0,-torch.inf)
attention_weights = torch.softmax(attention_scores,dim=-1)
out = (attention_weights @ v).transpose(1,2)
out = out.contiguous().view(-1,s,self.heads * self.v_head_dim)
return self.wo(out)
四、如何系统学习掌握AI大模型?
AI大模型作为人工智能领域的重要技术突破,正成为推动各行各业创新和转型的关键力量。抓住AI大模型的风口,掌握AI大模型的知识和技能将变得越来越重要。
学习AI大模型是一个系统的过程,需要从基础开始,逐步深入到更高级的技术。
这里给大家精心整理了一份
全面的AI大模型学习资源
,包括:AI大模型全套学习路线图(从入门到实战)、精品AI大模型学习书籍手册、视频教程、实战学习、面试题等,资料免费分享
!
1. 成长路线图&学习规划
要学习一门新的技术,作为新手一定要先学习成长路线图,方向不对,努力白费。
这里,我们为新手和想要进一步提升的专业人士准备了一份详细的学习成长路线图和规划。可以说是最科学最系统的学习成长路线。
2. 大模型经典PDF书籍
书籍和学习文档资料是学习大模型过程中必不可少的,我们精选了一系列深入探讨大模型技术的书籍和学习文档,它们由领域内的顶尖专家撰写,内容全面、深入、详尽,为你学习大模型提供坚实的理论基础。(书籍含电子版PDF)
3. 大模型视频教程
对于很多自学或者没有基础的同学来说,书籍这些纯文字类的学习教材会觉得比较晦涩难以理解,因此,我们提供了丰富的大模型视频教程,以动态、形象的方式展示技术概念,帮助你更快、更轻松地掌握核心知识。
4. 2024行业报告
行业分析主要包括对不同行业的现状、趋势、问题、机会等进行系统地调研和评估,以了解哪些行业更适合引入大模型的技术和应用,以及在哪些方面可以发挥大模型的优势。
5. 大模型项目实战
学以致用 ,当你的理论知识积累到一定程度,就需要通过项目实战,在实际操作中检验和巩固你所学到的知识,同时为你找工作和职业发展打下坚实的基础。
6. 大模型面试题
面试不仅是技术的较量,更需要充分的准备。
在你已经掌握了大模型技术之后,就需要开始准备面试,我们将提供精心整理的大模型面试题库,涵盖当前面试中可能遇到的各种技术问题,让你在面试中游刃有余。
全套的AI大模型学习资源已经整理打包,有需要的小伙伴可以
微信扫描下方CSDN官方认证二维码
,免费领取【保证100%免费
】