1.整体流程
先上一张图来整体理解下MLA的计算过程
2.实现代码
import math
import torch
import torch.nn as nn
class RMSNorm(nn.Module):
    """Root-mean-square layer normalization (no mean-centering, no bias).

    Normalizes each vector along the last dimension by its RMS, then applies
    a learned per-channel gain.
    """

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        # Learned gain g, initialized to ones so the layer starts as a no-op scale.
        self.weight = nn.Parameter(torch.ones(hidden_size))
        # Added under the square root to avoid division by zero.
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        # Compute statistics in float32 for numerical stability, then cast back
        # to the caller's dtype. (The original applied a redundant second
        # .float() and always returned float32, breaking mixed precision.)
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.float()
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)
def rotate_half(x):
    """Rotate pairs of channels in the last dimension: (x1, x2) -> (-x2, x1).

    Together with the cos/sin terms in RoPE this implements a 2-D rotation of
    paired channels. The original returned (x1, x2) unchanged — an identity —
    which silently disabled the rotary position encoding.
    """
    x1, x2 = x.chunk(2, dim=-1)
    # Negating the second half is what makes this a rotation.
    return torch.cat((-x2, x1), dim=-1)
def apply_rotate_pos_emb(q, k, cos, sin, unsqueeze_dim=2):
    """Apply rotary position embeddings (RoPE) to query and key tensors.

    cos/sin are broadcast over the heads axis via `unsqueeze_dim` (dim 2 for
    tensors laid out as [batch, seq, heads, head_dim]).
    """
    # Insert a broadcast dimension for the heads axis.
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    # Bug fix: the original multiplied rotate_half(k) by cos instead of sin,
    # so keys were never rotated correctly.
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
class RotaryEmbedding(nn.Module):
    """Precomputed cos/sin lookup tables for rotary position embeddings."""

    def __init__(self, dim, max_seq_len=1024):
        super().__init__()
        self.dim = dim
        self.max_seq_len = max_seq_len
        # Per-channel inverse frequencies: 10000^(-2i/dim) for i = 0..dim/2-1.
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        # Outer product of positions and frequencies -> [max_seq_len, dim/2].
        positions = torch.arange(max_seq_len).float()
        angles = torch.outer(positions, inv_freq)
        # Duplicate so the table spans the full rotary head dimension.
        table = torch.cat((angles, angles), dim=-1)
        self.register_buffer("cos_cached", table.cos())
        self.register_buffer("sin_cached", table.sin())

    def forward(self, q, k):
        seq_len = q.shape[1]
        # Slice the cache to the current sequence length and add a batch axis.
        cos = self.cos_cached[:seq_len, :].unsqueeze(0)
        sin = self.sin_cached[:seq_len, :].unsqueeze(0)
        return apply_rotate_pos_emb(q, k, cos, sin)
class MLA(nn.Module):
    """Multi-head Latent Attention (MLA), DeepSeek-style.

    Queries and keys/values are first projected down to low-rank latents
    (q_lora_rank / kv_lora_rank) and then projected back up per head. Only the
    small KV latent and the decoupled rotary key are cached — not the full
    per-head K/V — which is the KV-cache memory saving MLA is designed for.
    The per-head q/k dimension is split into a non-positional part
    (qk_nope_head_dim) and a rotary part (qk_rope_head_dim).
    """

    def __init__(self,
                 dim,
                 n_heads,
                 q_lora_rank,
                 kv_lora_rank,
                 qk_nope_head_dim,
                 qk_rope_head_dim,
                 v_head_dim,
                 max_seq_len,
                 max_batch_size):
        super().__init__()
        # Hidden (model) dimension.
        self.dim = dim
        # Number of attention heads.
        self.n_heads = n_heads
        # Low-rank dimension the query is compressed to.
        self.q_lora_rank = q_lora_rank
        # Low-rank dimension the key/value latent is compressed to.
        self.kv_lora_rank = kv_lora_rank
        # Per-head q/k dimension without rotary position encoding.
        self.qk_nope_head_dim = qk_nope_head_dim
        # Per-head q/k dimension carrying rotary position encoding.
        self.qk_rope_head_dim = qk_rope_head_dim
        # Total per-head q/k dimension.
        self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
        # Per-head value dimension.
        self.v_head_dim = v_head_dim
        self.max_seq_len = max_seq_len
        self.max_batch_size = max_batch_size
        # Query path: down-projection, norm, up-projection.
        self.wq_a = nn.Linear(self.dim, self.q_lora_rank)
        self.q_norm = RMSNorm(self.q_lora_rank)
        self.wq_b = nn.Linear(self.q_lora_rank, self.n_heads * self.qk_head_dim)
        # KV path: one down-projection yields the shared latent plus the
        # decoupled rotary key; the up-projection is shared by k_nope and v.
        self.wkv_a = nn.Linear(self.dim, self.kv_lora_rank + self.qk_rope_head_dim)
        self.kv_norm = RMSNorm(self.kv_lora_rank)
        self.wkv_b = nn.Linear(self.kv_lora_rank, self.n_heads * (self.qk_nope_head_dim + self.v_head_dim))
        self.wo = nn.Linear(self.n_heads * self.v_head_dim, self.dim)
        self.rotary_emb = RotaryEmbedding(self.qk_rope_head_dim)
        # Caches hold only the compressed KV latent and the rotary key.
        self.register_buffer("kv_cache", torch.zeros(self.max_batch_size, self.max_seq_len, self.kv_lora_rank))
        self.register_buffer("pe_cache", torch.zeros(self.max_batch_size, self.max_seq_len, self.qk_rope_head_dim))

    def forward(self, x, mask=None):
        """Run MLA over x of shape [bs, seq_len, dim]; returns the same shape.

        NOTE(review): caching starts at position 0 each call (no start_pos),
        so this supports prefill-style calls only — confirm against callers.
        """
        bs, seq_len, _ = x.shape
        # [bs, seq_len, q_lora_rank]
        q = self.wq_a(x)
        # [bs, seq_len, q_lora_rank]
        q = self.q_norm(q)
        # [bs, seq_len, n_heads*(qk_nope_head_dim+qk_rope_head_dim)]
        q = self.wq_b(q)
        # [bs, seq_len, n_heads, qk_nope_head_dim+qk_rope_head_dim]
        q = q.view(bs, seq_len, self.n_heads, self.qk_head_dim)
        # Split the last dimension into the non-rotary and rotary parts:
        # [bs, seq_len, n_heads, qk_nope_head_dim] and
        # [bs, seq_len, n_heads, qk_rope_head_dim].
        q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
        # [bs, seq_len, kv_lora_rank + qk_rope_head_dim]
        kv = self.wkv_a(x)
        # Split into the shared KV latent [bs, seq_len, kv_lora_rank] and the
        # decoupled rotary key [bs, seq_len, qk_rope_head_dim].
        kv, k_pe = torch.split(kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
        # Add a heads axis so k_pe matches q_pe: [bs, seq_len, 1, qk_rope_head_dim].
        k_pe = k_pe.unsqueeze(2)
        # Apply rotary position encoding.
        q_pe, k_pe = self.rotary_emb(q_pe, k_pe)
        # Drop the heads axis again: [bs, seq_len, qk_rope_head_dim].
        k_pe = k_pe.squeeze(2)
        kv = self.kv_norm(kv)
        # Cache the compressed latent that k and v are both derived from.
        self.kv_cache[:bs, :seq_len, :] = kv
        # Cache the rotary-position part of the key.
        self.pe_cache[:bs, :seq_len, :] = k_pe
        # [n_heads*(qk_nope_head_dim + v_head_dim), kv_lora_rank]
        wkv_b = self.wkv_b.weight
        # [n_heads, qk_nope_head_dim + v_head_dim, kv_lora_rank]
        wkv_b = wkv_b.view(self.n_heads, -1, self.kv_lora_rank)
        # ----------------------------- MLA core -----------------------------
        # Absorb the k up-projection into q: q_nope is roughly x*w_q, so this
        # computes x*w_q*w_k. Result: [bs, seq_len, n_heads, kv_lora_rank].
        q_nope = torch.einsum("bshd,hdc->bshc", q_nope, wkv_b[:, :self.qk_nope_head_dim])
        # Dot with the cached latent (x after the wkv_a down-projection):
        # non-rotary q/k similarity, [bs, seq_len, n_heads, seq_len].
        scores_nope = torch.einsum("bshc, btc->bsht", q_nope, self.kv_cache[:bs, :seq_len, :])
        # Rotary q/k similarity, [bs, seq_len, n_heads, seq_len].
        scores_pe = torch.einsum("bshr,btr->bsht", q_pe, self.pe_cache[:bs, :seq_len, :])
        # ----------------------------- MLA core -----------------------------
        # Sum both contributions and scale by sqrt of the full q/k head dim.
        scores = (scores_nope + scores_pe) / math.sqrt(self.qk_nope_head_dim + self.qk_rope_head_dim)
        if mask is not None:
            # Bug fix: the original called mask.unseqeeze(2) (typo), which
            # raised AttributeError whenever a mask was supplied.
            # Assumes mask broadcasts against [bs, seq_len, n_heads, seq_len]
            # after unsqueeze(2) — verify against callers.
            scores += mask.unsqueeze(2)
        scores = scores.softmax(dim=-1)
        # Weight the cached latent by the attention scores:
        # [bs, seq_len, n_heads, kv_lora_rank].
        x = torch.einsum("bsht,btc->bshc", scores, self.kv_cache[:bs, :seq_len, :])
        # Up-project with the v part of wkv_b: [bs, seq_len, n_heads, v_head_dim].
        x = torch.einsum("bshc,hdc->bshd", x, wkv_b[:, -self.v_head_dim:])
        # Merge heads and project back to the model dimension.
        x = x.contiguous().view(bs, seq_len, -1)
        x = self.wo(x)
        return x
if __name__ == '__main__':
    torch.manual_seed(0)
    torch.set_printoptions(precision=3, sci_mode=False)
    # Smoke test: batch of 1, sequence of 4 tokens, hidden size 16.
    x = torch.randn(1, 4, 16)
    dim = 16
    n_heads = 2
    q_lora_rank = 10
    kv_lora_rank = 6
    qk_nope_head_dim = 8
    qk_rope_head_dim = 4
    v_head_dim = 8
    max_seq_len = 10
    max_batch_size = 4
    # (Removed the unused `mode` variable from the original.)
    mla = MLA(dim=dim,
              n_heads=n_heads,
              q_lora_rank=q_lora_rank,
              kv_lora_rank=kv_lora_rank,
              qk_nope_head_dim=qk_nope_head_dim,
              qk_rope_head_dim=qk_rope_head_dim,
              v_head_dim=v_head_dim,
              max_seq_len=max_seq_len,
              max_batch_size=max_batch_size)
    print(mla(x))
    print(mla.kv_cache)
参考资料:
https://zhuanlan.zhihu.com/p/16730036197
llm_related/deepseek_learn at main · wyf3/llm_related · GitHub