最新大模型架构TTT模型代码解析(一)

这项来自斯坦福大学、加州大学伯克利分校、加州大学圣迭戈分校和 Meta 的研究提出了一个新颖的序列建模方法,称为测试时训练(Test-Time Training, TTT)层。TTT 层通过用机器学习模型取代 RNN 的隐藏状态,并使用输入 token 的实际梯度下降来压缩上下文。研究表明,这种方法在性能上可以优于现有的 Transformer 和 Mamba 架构。

TTT 层的设计亮点包括:

1. 替代自注意力机制:TTT 层直接取代了 Transformer 中的自注意力层,采用了线性模型(TTT-Linear)和两层 MLP(TTT-MLP)作为隐藏状态。
   
2. 线性复杂度架构:通过表达性记忆解锁线性复杂性架构,能够在上下文中训练具有数百万甚至数十亿个 token 的大语言模型(LLM)。

3. 更高效的上下文利用:TTT 层通过更高效的上下文压缩和模型记忆机制,相较于 RNN 层能更好地处理长上下文。

4. 更快的实际运行时间:尽管自注意力机制的复杂度是线性的,研究表明 TTT 层在实际运行时间上更快。

实验结果表明,TTT-Linear 和 TTT-MLP 在较大的上下文长度(如 8k token)下明显优于 Mamba 和 Transformer。具体观察到的优点包括:

- TTT 层在长上下文中的表现优于 Mamba,更具表现力并且更高效。
- 在相同 FLOPs 成本下,TTT-MLP 的困惑度(perplexity)比 TTT-Linear 更低,但由于 FLOPs 的额外成本,TTT-Linear 在效率方面有优势。

研究团队还在 JAX 和 PyTorch 上提供了代码,用于模型训练和测试,可以在以下链接中获取:
- JAX 代码:[https://github.com/test-time-training/ttt-lm-jax](https://github.com/test-time-training/ttt-lm-jax)
- PyTorch 推理代码:[https://github.com/test-time-training/ttt-lm-pytorch](https://github.com/test-time-training/ttt-lm-pytorch)

接下来我将对模型前半部分代码进行解析

  1. 定义配置类 TTTConfig

  2. 定义基础模块如 rotate_halfpermute_qk 等

  3. 定义 RMSNormSwiGluMLPRotaryEmbedding 等组件类

  4. 定义 Conv 卷积层

  5. 定义 TTTCache 缓存机制

TTTConfig 类

定义了一个用于配置 TTT 模型的类 `TTTConfig`,还包含了一些辅助函数,用于处理查询和键张量的旋转位置嵌入。这些函数和类主要用于模型的配置和预处理步骤,就不详细给出实现了

辅助函数

rotate_half 函数

这个函数将输入张量 x 一分为二,然后交换这两部分的位置,并返回结果。

permute_qk 函数

这个函数对查询(q)和键(k)张量进行重新排序,以匹配 JAX 实现中的维度顺序。

undo_permute_qk 函数
def undo_permute_qk(q, k):
    bsz, num_head, seq_len, head_dim = q.shape
    q = q.reshape(bsz, num_head, seq_len, 2, head_dim // 2).transpose(3, 4).reshape(bsz, num_head, seq_len, head_dim)
    k = k.reshape(bsz, num_head, seq_len, 2, head_dim // 2).transpose(3, 4).reshape(bsz, num_head, seq_len, head_dim)
    return q, k

这个函数将 permute_qk 函数的操作逆转,恢复原始维度顺序。

apply_rotary_pos_emb 
函数
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed

这个函数将旋转位置嵌入应用于查询(q)和键(k)张量。通过对 cossin 张量进行适当的扩展,以便它们可以正确地广播到 qk 的维度上。

RMSNorm类:

    • 主要用于对输入张量进行Root Mean Square Layer Normalization。
    • 初始化方法定义了权重和防止除零的小常数。
    • 前向传播方法对输入张量进行归一化处理,并乘以权重

​​​​​​​SwiGluMLP类:

    • 这是一个使用SiLU激活函数和Gated Linear Units (GLU)的多层感知机。
    • 初始化方法定义了相关的线性层和激活函数。
    • 前向传播方法根据配置中的pretraining_tp值,对输入张量进行切片和线性变换操作。
import torch
from torch import nn
import torch.nn.functional as F

# 激活函数字典,假设在其他地方定义
ACT2FN = {
    "silu": torch.nn.SiLU(),
    # 其他激活函数可以在这里添加
}

class RMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """初始化RMSNorm层。

        Args:
            hidden_size (int): 隐藏层维数。
            eps (floa
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值