LLM Inference: An MLA Implementation

1. Overall Flow

First, an overall picture of the MLA computation flow.
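
MLA (Multi-head Latent Attention) replaces the full per-head K/V projections with a low-rank "down-project, then up-project" scheme plus a decoupled RoPE branch. As a rough sketch (notation loosely follows the DeepSeek-V2 paper and is a simplification, not an exact mirror of the code below), the per-token projections are:

$$
\begin{aligned}
c^{Q} &= W^{DQ} h, \qquad q^{\mathrm{nope}} = W^{UQ} c^{Q}, \qquad q^{\mathrm{pe}} = \mathrm{RoPE}(W^{QR} c^{Q}) \\
c^{KV} &= W^{DKV} h, \qquad k^{\mathrm{nope}} = W^{UK} c^{KV}, \qquad v = W^{UV} c^{KV}, \qquad k^{\mathrm{pe}} = \mathrm{RoPE}(W^{KR} h)
\end{aligned}
$$

At inference time only the small latent $c^{KV}$ (kv_lora_rank dims per token) and the shared $k^{\mathrm{pe}}$ (qk_rope_head_dim dims) need to be cached; in the code below they correspond to kv_cache and pe_cache.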

2. Implementation Code

import math
import torch
import torch.nn as nn


class RMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        # RMSNorm's learnable gain parameter g
        self.weight = nn.Parameter(torch.ones(hidden_size))
        # epsilon to keep the denominator away from zero
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        hidden_states = hidden_states.float()
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states


def rotate_half(x):
    # rotate the last dimension by half: (x1, x2) -> (-x2, x1), as used by RoPE
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def apply_rotate_pos_emb(q, k, cos, sin, unsqueeze_dim=2):
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)

    q_embed = (q*cos) + (rotate_half(q)*sin)
    k_embed = (k*cos) + (rotate_half(k)*sin)
    return q_embed, k_embed


class RotaryEmbedding(nn.Module):
    def __init__(self, dim, max_seq_len=1024):
        super(RotaryEmbedding, self).__init__()
        self.dim = dim
        self.max_seq_len = max_seq_len
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float()/dim))
        t = torch.arange(max_seq_len).float().unsqueeze(1)
        freqs = t @ inv_freq.unsqueeze(0)
        freqs = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", freqs.cos())
        self.register_buffer("sin_cached", freqs.sin())

    def forward(self, q, k):
        cos = self.cos_cached[:q.shape[1], :].unsqueeze(0)
        sin = self.sin_cached[:q.shape[1], :].unsqueeze(0)
        return apply_rotate_pos_emb(q, k, cos, sin)


class MLA(nn.Module):
    def __init__(self,
                 dim,
                 n_heads,
                 q_lora_rank,
                 kv_lora_rank,
                 qk_nope_head_dim,
                 qk_rope_head_dim,
                 v_head_dim,
                 max_seq_len,
                 max_batch_size):
        super().__init__()
        # hidden (model) dimension
        self.dim = dim
        # number of attention heads
        self.n_heads = n_heads
        # rank that q is compressed down to
        self.q_lora_rank = q_lora_rank
        # rank that k/v are jointly compressed down to
        self.kv_lora_rank = kv_lora_rank
        # per-head dim of the q/k part without rotary position encoding
        self.qk_nope_head_dim = qk_nope_head_dim
        # per-head dim of the q/k part with rotary position encoding
        self.qk_rope_head_dim = qk_rope_head_dim
        # total per-head dim of q/k
        self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
        # per-head dim of v
        self.v_head_dim = v_head_dim
        self.max_seq_len = max_seq_len
        self.max_batch_size = max_batch_size

        self.wq_a = nn.Linear(self.dim, self.q_lora_rank)
        self.q_norm = RMSNorm(self.q_lora_rank)
        self.wq_b = nn.Linear(self.q_lora_rank, self.n_heads*self.qk_head_dim)
        self.wkv_a = nn.Linear(self.dim, self.kv_lora_rank + self.qk_rope_head_dim)
        self.kv_norm = RMSNorm(self.kv_lora_rank)
        self.wkv_b = nn.Linear(self.kv_lora_rank, self.n_heads*(self.qk_nope_head_dim + self.v_head_dim))
        self.wo = nn.Linear(self.n_heads*self.v_head_dim, self.dim)

        self.rotary_emb = RotaryEmbedding(self.qk_rope_head_dim)

        self.register_buffer("kv_cache", torch.zeros(self.max_batch_size, self.max_seq_len, self.kv_lora_rank))
        self.register_buffer("pe_cache", torch.zeros(self.max_batch_size, self.max_seq_len, self.qk_rope_head_dim))
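        # (added note) kv_cache stores the compressed latent c_kv for each position and pe_cache
        # stores the decoupled RoPE key k_pe; at inference time only these two small tensors are
        # cached instead of full per-head K/V, which is the memory saving MLA is designed for.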

    def forward(self, x, mask=None):
        bs, seq_len, _ = x.shape
        # [bs, seq_len, q_lora_rank]
        q = self.wq_a(x)
        # [bs, seq_len, q_lora_rank]
        q = self.q_norm(q)
        # [bs, seq_len, n_heads*(qk_nope_head_dim+qk_rope_head_dim)]
        q = self.wq_b(q)
        # [bs, seq_len, n_heads, (qk_nope_head_dim+qk_rope_head_dim)]
        q = q.view(bs, seq_len, self.n_heads,  self.qk_head_dim)
        # split along the last dimension
        #                                                                 --> [bs, seq_len, n_heads, qk_nope_head_dim]
        #                                                               --
        # [bs, seq_len, n_heads, (qk_nope_head_dim+qk_rope_head_dim)] --
        #                                                               --
        #                                                                 --> [bs, seq_len, n_heads, qk_rope_head_dim]
        q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
        # [bs, seq_len, kv_lora_rank + qk_rope_head_dim]
        kv = self.wkv_a(x)
        # split along the last dimension
        #                                                    --> [bs, seq_len, kv_lora_rank]
        #                                                  --
        # [bs, seq_len, kv_lora_rank + qk_rope_head_dim] --
        #                                                  --
        #                                                    --> [bs, seq_len, qk_rope_head_dim]
        kv, k_pe = torch.split(kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
        # add a head dimension so k_pe matches q's layout: [bs, seq_len, 1, qk_rope_head_dim]
        k_pe = k_pe.unsqueeze(2)
        # apply rotary position embedding
        q_pe, k_pe = self.rotary_emb(q_pe, k_pe)

        # squeeze the head dimension back out: [bs, seq_len, qk_rope_head_dim]
        k_pe = k_pe.squeeze(2)
        kv = self.kv_norm(kv)
        # cache the compressed latent shared by k and v (wkv_b later up-projects it to k/v)
        self.kv_cache[:bs, :seq_len, :] = kv
        # cache the decoupled RoPE part of k
        self.pe_cache[:bs, :seq_len, :] = k_pe
        # [n_heads*(qk_nope_head_dim + v_head_dim), kv_lora_rank]
        wkv_b = self.wkv_b.weight
        # [n_heads, (qk_nope_head_dim + v_head_dim), kv_lora_rank]
        wkv_b = wkv_b.view(self.n_heads, -1, self.kv_lora_rank)
        # ################################# The core of MLA #################################
        # q_nope can be read as x*W_q; multiplying it by W_uk (taken from wkv_b) gives x*W_q*W_uk,
        # i.e. k's up-projection is absorbed into the query; result shape [bs, seq_len, n_heads, kv_lora_rank]
        q_nope = torch.einsum("bshd,hdc->bshc", q_nope, wkv_b[:, :self.qk_nope_head_dim])
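        # (added note) the algebra behind the absorption trick, per head:
        #   score_nope = q_nope^T · k_nope = q_nope^T · (W_uk · c_kv) = (W_uk^T · q_nope)^T · c_kv
        # with W_uk = wkv_b[:, :qk_nope_head_dim]; the einsum above folds W_uk into the query,
        # so attention can be scored directly against the cached latent kv_cache without ever
        # materializing the full per-head keys.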
        # then multiply by "k", which here is just the down-projected x (x passed through wkv_a),
        # giving the q/k similarity for the non-RoPE part; result shape [bs, seq_len, n_heads, seq_len]
        scores_nope = torch.einsum("bshc,btc->bsht", q_nope, self.kv_cache[:bs, :seq_len, :])
        # q/k similarity for the RoPE part; result shape [bs, seq_len, n_heads, seq_len]
        scores_pe = torch.einsum("bshr,btr->bsht", q_pe, self.pe_cache[:bs, :seq_len, :])
        # ################################# The core of MLA #################################
        # add the two score parts, then scale by sqrt of the full q/k head dim
        scores = (scores_nope + scores_pe) / math.sqrt(self.qk_head_dim)
        if mask is not None:
            scores += mask.unsqueeze(2)

        scores = scores.softmax(dim=-1)
        # with the q/k similarities (attention weights) computed, multiply by v; v itself comes
        # from the cached latent kv and the v part of the wkv_b matrix
        # first aggregate against the cached kv latent; shape [bs, seq_len, n_heads, kv_lora_rank]
        x = torch.einsum("bsht,btc->bshc", scores, self.kv_cache[:bs, :seq_len, :])
        # then multiply with wkv_b[:, -self.v_head_dim:]; shape [bs, seq_len, n_heads, v_head_dim]
        x = torch.einsum("bshc,hdc->bshd", x, wkv_b[:, -self.v_head_dim:])
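        # (added note) the value side uses the same trick: v = W_uv · c_kv with
        # W_uv = wkv_b[:, -v_head_dim:], so scores · V = W_uv · (scores · c_kv); the two einsums
        # above first aggregate the cached latents with the attention weights, then up-project once.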

        x = x.contiguous().view(bs, seq_len, -1)
        x = self.wo(x)

        return x


if __name__ == '__main__':
    torch.manual_seed(0)
    torch.set_printoptions(precision=3, sci_mode=False)

    x = torch.randn(1, 4, 16)

    dim = 16
    n_heads = 2
    q_lora_rank = 10
    kv_lora_rank = 6
    qk_nope_head_dim = 8
    qk_rope_head_dim = 4
    v_head_dim = 8
    max_seq_len = 10
    max_batch_size = 4

    mla = MLA(dim=dim,
              n_heads=n_heads,
              q_lora_rank=q_lora_rank,
              kv_lora_rank=kv_lora_rank,
              qk_nope_head_dim=qk_nope_head_dim,
              qk_rope_head_dim=qk_rope_head_dim,
              v_head_dim=v_head_dim,
              max_seq_len=max_seq_len,
              max_batch_size=max_batch_size)

    print(mla(x))
    print(mla.kv_cache)
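
The toy run above passes no mask, so every position attends to every other. For causal decoding you would pass an additive mask into forward. A minimal sketch, reusing the mla and x objects from the script above; the layout (0 on/below the diagonal, -inf strictly above) matches the mask.unsqueeze(2) broadcast inside forward:

# additive causal mask: 0 on and below the diagonal, -inf strictly above it
seq_len = x.shape[1]
causal_mask = torch.full((seq_len, seq_len), float('-inf')).triu(1)
# forward expects a mask of shape [bs, seq_len_q, seq_len_k] (broadcast over heads via unsqueeze(2))
out = mla(x, mask=causal_mask.unsqueeze(0))
print(out.shape)  # torch.Size([1, 4, 16])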

References:

https://zhuanlan.zhihu.com/p/16730036197

wyf3/llm_related (deepseek_learn directory), GitHub
