万字长文,代码详解Memory3:革命性RAG模型如何重新定义大规模语言模型

我是芝士AI吃鱼,原创 NLP、LLM、超长文知识分享
热爱分享前沿技术知识,寻找志同道合小伙伴
公众号 :芝士AI吃鱼

image.png

1、引言

大型语言模型(LLMs)在近年来取得了巨大的成功,展现出惊人的能力。然而,随着模型规模的不断增大,LLMs的训练和推理成本也在急剧上升。如何在保持或提升性能的同时降低成本,成为了当前LLM研究的一个重要方向。
在这篇技术博客中,我们将详细介绍一种名为Memory3的创新模型,它通过引入显式记忆机制来优化知识存储,从而大幅提高模型效率。Memory3的核心思想是:

  1. 将部分知识从模型参数外化到显式记忆中,降低模型参数量和训练成本。
  2. 设计高效的显式记忆读写机制,在推理时动态调用所需知识,避免知识遍历问题。

Memory3模型的主要贡献包括:

  1. 提出了记忆电路理论,为知识外化提供了理论支持。
  2. 设计了高效的显式记忆机制,包括记忆稀疏化、并行位置编码等技术。
  3. 提出了两阶段预训练方案,有效促进记忆形成。
  4. 在多项任务上超越了更大规模的模型,同时保持较快的推理速度。

image.png

2、理论基础

Memory3模型的核心创新在于引入显式记忆机制,为此,研究团队提出了一套完整的理论框架,包括知识和记忆的定义、记忆电路理论、以及可分离知识和可模仿知识的概念。这些理论为知识外化和显式记忆机制提供了坚实的基础。

2.1 知识和记忆的定义

在Memory3的理论框架中,知识被定义为LLM计算图中的一个电路。具体来说:

  1. 计算图:
    • 节点:所有注意力层和MLP层的隐藏向量
    • 边:这些层内的所有激活函数
  2. 电路:
    • 计算图中同态子图的等价类
    • 具有非可忽略边权重
    • 具有可解释的输入-输出关系
  3. 知识:
    • 特定知识:输入具有可解释含义,输出基本固定
    • 抽象知识:其他情况

这种定义将知识与LLM的内部计算机制直接关联,为后续的知识外化奠定了基础。

2.2 记忆电路理论

记忆电路理论是Memory3模型的核心理论基础,它定义了不同类型的记忆及其特性:

  1. 隐式记忆(模型参数):
    • 写入成本高,读取成本低
    • 适合存储频繁使用的知识
  2. 显式记忆(Memory3引入):
    • 写入和读取成本适中
    • 适合存储使用频率中等的知识
  3. 外部信息(RAG中的文本检索):
    • 写入成本低,读取成本高
    • 适合存储很少使用的知识

这种记忆层次结构类似于人脑的记忆机制,为LLM提供了更灵活和高效的知识存储方案。

2.3 可分离知识和可模仿知识

为了确定哪些知识可以外化到显式记忆中,研究团队引入了可分离知识和可模仿知识的概念:

  1. 可分离知识:
    • 定义:存在另一个LLM M,在没有该知识时无法高概率生成正确输出,但在给定特定前缀后可以高概率生成正确输出
    • 特点:可以通过检索示例或抽象描述来激活
  2. 可模仿知识:
    • 定义:任何该知识的实现都可以作为激活前缀
    • 特点:是可分离知识的一个子集

研究发现,所有特定知识都是可模仿的,因此可以被外化到显式记忆中。这一发现为Memory3模型的设计提供了理论依据。
image.png

3、Memory3模型架构

基于前面介绍的理论基础,Memory3模型设计了一套创新的架构,其核心是显式记忆机制。这一章节将详细介绍Memory3的模型结构、显式记忆机制的实现,以及记忆稀疏化和存储方法。

3.1 显式记忆机制

Memory3的显式记忆机制设计目标是实现适中的写入和读取成本,同时尽可能减少对现有Transformer架构的修改。其主要特点包括:

  1. 写入过程:
    • 在推理前,将每个参考文本转换为显式记忆
    • 显式记忆是从自注意力层的key-value向量中选择得到
    • 每个参考文本独立处理,避免长上下文注意力计算
  2. 读取过程:
    • 在推理时,从存储设备加载检索到的显式记忆
    • 将显式记忆与常规上下文key-value向量连接,通过自注意力层读取
    • 每个记忆只包含少量key-value,大幅减少额外计算和存储需求
  3. 检索频率:
    • 每生成64个token,丢弃当前记忆,检索5个新记忆
    • 处理prompt时,每64个token检索5个记忆
  4. 检索方法:
    • 使用BGE-M3多语言BERT模型进行向量嵌入
    • 采用FAISS进行向量索引和检索
def memory_retrieval(query_chunk):
    # 使用BGE-M3模型进行向量嵌入
    query_embedding = bge_m3_model.encode(query_chunk)
    
    # 使用FAISS检索最相关的5个记忆
    _, memory_ids = faiss_index.search(query_embedding, k=5)
    
    # 从存储设备加载显式记忆
    explicit_memories = load_memories(memory_ids)
    
    return explicit_memories

def memory_augmented_generation(input_text):
    tokens = tokenize(input_text)
    generated_tokens = []
    
    for i in range(0, len(tokens), 64):
        chunk = tokens[i:i+64]
        memories = memory_retrieval(chunk)
        
        # 将显式记忆与上下文连接,进行生成
        output = generate_with_memories(chunk, memories)
        generated_tokens.extend(output)
    
    return detokenize(generated_tokens)

每64个token进行一次记忆检索,然后将检索到的显式记忆与当前上下文结合进行生成。

3.2 模型结构

Memory3模型的基本结构仍然是Transformer,但在自注意力机制上进行了修改以支持显式记忆。主要特点包括:

  1. 参数配置:
    • Transformer块数: L = 44
    • 查询头数: H = 40
    • Key-Value头数: H_kv = 8 (使用分组查询注意力, GQA)
    • 头维度: d_h = 80
    • 隐藏维度: d = H * d_h = 3200
    • MLP宽度: W = d = 3200
    • 词汇表大小: n_vocab = 60416
    • 记忆层数: L_mem = 22 (前半部分层为记忆层)
  2. 注意力计算:
    对于每个记忆头h在层l,其输出Y^l,h计算如下:

Y l , h i = softmax ( ( X l , h i W l , h Q ) ⋅ concat ( K l , h 0 , … , K l , h 4 , X l , h [ : i ] W l , h K ) T d h ) ⋅ concat ( V l , h 0 , … , V l , h 4 , X l , h [ : i ] W l , h V ) W l , h O Y^{l,h_i} = \text{softmax}\left(\frac{(X^{l,h_i} W^{l,h_Q}) \cdot \text{concat}(K^{l,h_0}, \ldots, K^{l,h_4}, X^{l,h_{[:i]}} W^{l,h_K})^T}{\sqrt{d_h}}\right) \cdot \text{concat}(V^{l,h_0}, \ldots, V^{l,h_4}, X^{l,h_{[:i]}} W^{l,h_V}) W^{l,h_O} Yl,hi=softmax(dh (Xl,hiWl,hQ)concat(Kl,h0,,Kl,h4,Xl,h[:i]Wl,hK)T)concat(Vl,h0,,Vl,h4,Xl,h[:i]Wl,hV)Wl,hO
其中Kl,h_j和Vl,h_j是显式记忆的key和value。

  1. 位置编码:
    • 采用旋转位置编码(RoPE)
    • 所有显式记忆使用并行位置编码,位置都在0-127范围内
  2. 优化设计:
    • 仅在前半部分层使用显式记忆
    • 使用分组查询注意力(GQA)减少key-value头数
    • 对每个记忆头只选择8个token参与注意力计算
  3. 记忆整合:
    为了更好地整合显式记忆和常规上下文,Memory3引入了一个特殊的BOS token:
class Memory3Model(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.embed = nn.Embedding(config.vocab_size, config.hidden_size)
        self.layers = nn.ModuleList([TransformerLayer(config) for _ in range(config.num_layers)])
        self.ln_f = nn.LayerNorm(config.hidden_size)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        
        # 特殊的BOS token嵌入
        self.reference_bos = nn.Parameter(torch.randn(config.hidden_size))
    
    def forward(self, input_ids, attention_mask, memories=None):
        x = self.embed(input_ids)
        
        # 插入特殊的Reference BOS
        if memories is not None:
            x = torch.cat([self.reference_bos.unsqueeze(0).unsqueeze(0), x], dim=1)
            attention_mask = torch.cat([torch.ones(attention_mask.shape[0], 1, device=attention_mask.device), attention_mask], dim=1)
        
        for i, layer in enumerate(self.layers):
            x = layer(x, attention_mask, memories if i < self.config.num_memory_layers else None)
        
        x = self.ln_f(x)
        logits = self.lm_head(x)
        
        return logits
  1. 并行位置编码:
    Memory3使用旋转位置编码(RoPE),并为所有显式记忆应用并行位置编码:
class RotaryEmbedding(nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, base=10000):
        super().__init__()
        inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('inv_freq', inv_freq)
        self.max_seq_len_cached = max_position_embeddings
        t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device).type_as(self.inv_freq)
        freqs = torch.einsum('i,j->ij', t, self.inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer('cos_cached', emb.cos()[None, None, :, :])
        self.register_buffer('sin_cached', emb.sin()[None, None, :, :])

    def forward(self, x, seq_len=None):
        if seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len)
        return (
            self.cos_cached[:, :, :seq_len, ...],
            self.sin_cached[:, :, :seq_len, ...]
        )

def apply_rotary_pos_emb(q, k, cos, sin):
    return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)

class TransformerLayer(nn.Module):
    def __init__(self, config):
        super().__init__()
        # ... 其他初始化代码 ...
        self.rotary_emb = RotaryEmbedding(config.head_dim)
    
    def forward(self, hidden_states, attention_mask, memories=None):
        # ... 其他前向传播代码 ...
        
        # 应用RoPE
        q, k = self.rotary_emb(q, k)
        
        if memories is not None:
            # 为显式记忆应用并行位置编码
            mem_pos = torch.arange(128, device=q.device)
            mem_cos, mem_sin = self.rotary_emb(mem_pos)
            for mem in memories:
                mem.k, mem.v = apply_rotary_pos_emb(mem.k, mem.v, mem_cos, mem_sin)
        
        # ... 继续注意力计算 ...

3.3 记忆稀疏化和存储

为了解决显式记忆占用空间过大的问题,Memory3采用了多维度的稀疏化策略:

  1. 层维度:
    • 只在前22层(共44层)使用显式记忆
  2. 头维度:
    • 使用分组查询注意力,将key-value头数减少到8个
  3. token维度:
    • 每个key-value头只选择8个最重要的token
    • 选择标准:基于无mask和位置编码的注意力权重
  4. 向量维度:
    • 可选使用向量量化器进行压缩
    • 压缩率约为11.4倍

通过这些稀疏化策略,Memory3将显式记忆的存储需求从7.17PB压缩到了45.9TB(不使用向量压缩)或4.02TB(使用向量压缩)。
image.png

  1. 稀疏化实现:
    以下代码展示了如何实现token维度的稀疏化:
def sparsify_memory(memory, top_k=8):
    # 计算注意力权重
    attn_weights = torch.einsum('bhid,bhjd->bhij', memory.q, memory.k.transpose(2, 3)) / math.sqrt(memory.q.size(-1))
    attn_weights = attn_weights.softmax(dim=-1)
    
    # 选择top-k的token
    _, top_indices = torch.topk(attn_weights.sum(dim=(0, 1)), k=top_k, dim=-1)
    
    # 稀疏化memory
    memory.k = memory.k[:, :, top_indices, :]
    memory.v = memory.v[:, :, top_indices, :]
    
    return memory

class Memory3Model(nn.Module):
    # ... 其他代码 ...
    
    def retrieve_and_sparsify_memories(self, query):
        memories = self.retrieve_memories(query)
        return [sparsify_memory(mem) for mem in memories]
  1. 向量压缩:
    使用FAISS库实现向量量化压缩:
import faiss

class VectorCompressor:
    def __init__
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

芝士AI吃鱼

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值