最新大模型架构TTT模型代码解析(一)

这项来自斯坦福大学、加州大学伯克利分校、加州大学圣迭戈分校和 Meta 的研究提出了一个新颖的序列建模方法,称为测试时训练(Test-Time Training, TTT)层。TTT 层通过用机器学习模型取代 RNN 的隐藏状态,并使用输入 token 的实际梯度下降来压缩上下文。研究表明,这种方法在性能上可以优于现有的 Transformer 和 Mamba 架构。

TTT 层的设计亮点包括:

1. 替代自注意力机制:TTT 层直接取代了 Transformer 中的自注意力层,采用了线性模型(TTT-Linear)和两层 MLP(TTT-MLP)作为隐藏状态。
   
2. 线性复杂度架构:通过表达性记忆解锁线性复杂性架构,能够在上下文中训练具有数百万甚至数十亿个 token 的大语言模型(LLM)。

3. 更高效的上下文利用:TTT 层通过更高效的上下文压缩和模型记忆机制,相较于 RNN 层能更好地处理长上下文。

4. 更快的实际运行时间:尽管自注意力机制的复杂度是线性的,研究表明 TTT 层在实际运行时间上更快。

实验结果表明,TTT-Linear 和 TTT-MLP 在较大的上下文长度(如 8k token)下明显优于 Mamba 和 Transformer。具体观察到的优点包括:

- TTT 层在长上下文中的表现优于 Mamba,更具表现力并且更高效。
- 在相同 FLOPs 成本下,TTT-MLP 的困惑度(perplexity)比 TTT-Linear 更低,但由于 FLOPs 的额外成本,TTT-Linear 在效率方面有优势。

研究团队还在 JAX 和 PyTorch 上提供了代码,用于模型训练和测试,可以在以下链接中获取:
- JAX 代码:[https://github.com/test-time-training/ttt-lm-jax](https://github.com/test-time-training/ttt-lm-jax)
- PyTorch 推理代码:[https://github.com/test-time-training/ttt-lm-pytorch](https://github.com/test-time-training/ttt-lm-pytorch)

接下来我将对模型前半部分代码进行解析

  1. 定义配置类 TTTConfig

  2. 定义基础模块如 rotate_halfpermute_qk 等

  3. 定义 RMSNormSwiGluMLPRotaryEmbedding 等组件类

  4. 定义 Conv 卷积层

  5. 定义 TTTCache 缓存机制

TTTConfig 类

定义了一个用于配置 TTT 模型的类 `TTTConfig`,还包含了一些辅助函数,用于处理查询和键张量的旋转位置嵌入。这些函数和类主要用于模型的配置和预处理步骤,就不详细给出实现了

辅助函数

rotate_half 函数

这个函数将输入张量 x 一分为二,然后交换这两部分的位置,并返回结果。

permute_qk 函数

这个函数对查询(q)和键(k)张量进行重新排序,以匹配 JAX 实现中的维度顺序。

undo_permute_qk 函数
def undo_permute_qk(q, k):
    bsz, num_head, seq_len, head_dim = q.shape
    q = q.reshape(bsz, num_head, seq_len, 2, head_dim // 2).transpose(3, 4).reshape(bsz, num_head, seq_len, head_dim)
    k = k.reshape(bsz, num_head, seq_len, 2, head_dim // 2).transpose(3, 4).reshape(bsz, num_head, seq_len, head_dim)
    return q, k

这个函数将 permute_qk 函数的操作逆转,恢复原始维度顺序。

apply_rotary_pos_emb 
函数
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed

这个函数将旋转位置嵌入应用于查询(q)和键(k)张量。通过对 cossin 张量进行适当的扩展,以便它们可以正确地广播到 qk 的维度上。

RMSNorm类:

    • 主要用于对输入张量进行Root Mean Square Layer Normalization。
    • 初始化方法定义了权重和防止除零的小常数。
    • 前向传播方法对输入张量进行归一化处理,并乘以权重

​​​​​​​SwiGluMLP类:

    • 这是一个使用SiLU激活函数和Gated Linear Units (GLU)的多层感知机。
    • 初始化方法定义了相关的线性层和激活函数。
    • 前向传播方法根据配置中的pretraining_tp值,对输入张量进行切片和线性变换操作。
import torch
from torch import nn
import torch.nn.functional as F

# 激活函数字典,假设在其他地方定义
ACT2FN = {
    "silu": torch.nn.SiLU(),
    # 其他激活函数可以在这里添加
}

class RMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """初始化RMSNorm层。

        Args:
            hidden_size (int): 隐藏层维数。
            eps (float): 防止除零的小常数。
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))  # 初始化权重参数为全1
        self.variance_epsilon = eps  # 设置epsilon值

    def forward(self, hidden_states):
        """前向传播函数。

        Args:
            hidden_states (torch.Tensor): 输入隐藏状态张量。

        Returns:
            torch.Tensor: 正则化后的隐藏状态张量。
        """
        input_dtype = hidden_states.dtype  # 保存输入张量的dtype
        hidden_states = hidden_states.to(torch.float32)  # 将输入转换为float32
        variance = hidden_states.pow(2).mean(-1, keepdim=True)  # 计算方差
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)  # 进行方差归一化
        return self.weight * hidden_states.to(input_dtype)  # 恢复原始dtype并乘以权重

class SwiGluMLP(nn.Module):
    def __init__(self, config):
        """初始化SwiGluMLP层。

        Args:
            config (TTTConfig): 模型配置对象。
        """
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size  # 从配置中获取隐藏层大小
        self.intermediate_size = config.intermediate_size  # 从配置中获取中间层大小
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)  # 定义线性层
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)  # 定义线性层
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)  # 定义线性层
        self.act_fn = ACT2FN[config.hidden_act]  # 从激活函数字典中获取激活函数

    def forward(self, x):
        """前向传播函数。

        Args:
            x (torch.Tensor): 输入张量。

        Returns:
            torch.Tensor: 输出张量。
        """
        if self.config.pretraining_tp > 1:
            # 如果预训练tp值大于1,进行切分操作
            slice = self.intermediate_size // self.config.pretraining_tp  # 计算切分大小
            gate_proj_slices = self.gate_proj.weight.split(slice, dim=0)  # 切分gate_proj权重
            up_proj_slices = self.up_proj.weight.split(slice, dim=0)  # 切分up_proj权重
            down_proj_slices = self.down_proj.weight.split(slice, dim=1)  # 切分down_proj权重

            # 分别对各个切片进行线性变换,然后拼接
            gate_proj = torch.cat(
                [F.linear(x, gate_proj_slices[i]) for i in range(self.config.pretraining_tp)],
                dim=-1,
            )
            up_proj = torch.cat(
                [F.linear(x, up_proj_slices[i]) for i in range(self.config.pretraining_tp)],
                dim=-1,
            )

            # 计算中间状态,并将其切分
            intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2)

            # 对每个切片进行线性变换,然后相加
            down_proj = [
                F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.config.pretraining_tp)
            ]
            down_proj = sum(down_proj)  # 将所有切片相加
        else:
            # 如果预训练tp值不大于1,直接进行线性变换和激活函数计算
            down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))

        return down_proj  # 返回输出张量

class RotaryEmbedding

import torch
from torch import nn

class RotaryEmbedding(nn.Module):
    def __init__(
        self,
        dim,
        max_position_embeddings=16,
        base=10000,
        device=None,
        scaling_factor=1.0,
    ):
        """初始化RotaryEmbedding层。

        Args:
            dim (int): 位置嵌入的维度。
            max_position_embeddings (int, optional): 最大位置嵌入数。默认值为16。
            base (int, optional): 基数。默认值为10000。
            device (torch.device, optional): 设备。默认值为None。
            scaling_factor (float, optional): 缩放因子。默认值为1.0。
        """
        super().__init__()
        self.scaling_factor = scaling_factor  # 设置缩放因子
        self.dim = dim  # 位置嵌入的维度
        self.max_position_embeddings = max_position_embeddings  # 最大位置嵌入数
        self.base = base  # 基数
        # 计算逆频率
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
        # 注册为buffer,以便在模型保存和加载时保持不变
        self.register_buffer("inv_freq", inv_freq, persistent=False)

    @torch.no_grad()
    def forward(self, x, position_ids):
        """前向传播函数。

        Args:
            x (torch.Tensor): 输入张量,形状为[bs, num_attention_heads, seq_len, head_size]。
            position_ids (torch.Tensor): 位置ID张量。

        Returns:
            tuple: 返回余弦和正弦位置嵌入张量。
        """
        # 扩展逆频率张量的维度,以适应输入张量的形状
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
        # 扩展位置ID张量的维度
        position_ids_expanded = position_ids[:, None, :].float()
        
        # 强制使用float32,以避免bfloat16在长上下文中失去精度
        # 参见https://github.com/huggingface/transformers/pull/29285
        device_type = x.device.type
        device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
        
        # 使用不自动混合精度的上下文进行计算
        with torch.autocast(device_type=device_type, enabled=False):
            # 计算频率嵌入
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            # 将频率嵌入沿最后一个维度拼接
            emb = torch.cat((freqs, freqs), dim=-1)
            # 计算余弦和正弦位置嵌入
            cos = emb.cos()
            sin = emb.sin()
        
        # 返回余弦和正弦位置嵌入张量,并将它们的dtype设置为输入张量的dtype
        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)

class Conv

import torch
from torch import nn
import torch.nn.functional as F

# 假设我们已经定义了RMSNorm类

class Conv(nn.Module):
    def __init__(self, config, layer_idx):
        """初始化Conv层。

        Args:
            config (TTTConfig): 模型配置对象。
            layer_idx (int): 当前层的索引。
        """
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx

        # 定义RMSNorm层
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        # 定义一维卷积层
        self.conv = nn.Conv1d(
            config.hidden_size,  # 输入通道数
            config.hidden_size,  # 输出通道数
            bias=True,  # 是否使用偏置
            kernel_size=config.conv_kernel,  # 卷积核大小
            groups=config.hidden_size,  # 组数(组卷积)
            padding=config.conv_kernel - 1,  # 填充大小
        )

    def __call__(self, hidden_states, cache_params=None):
        """前向传播函数。

        Args:
            hidden_states (torch.Tensor): 输入隐藏状态张量。
            cache_params (dict, optional): 缓存参数。默认值为None。

        Returns:
            torch.Tensor: 输出隐藏状态张量。
        """
        seq_len = hidden_states.shape[1]
        hidden_states = self.norm(hidden_states)  # 归一化
        hidden_states = hidden_states.transpose(1, 2)  # 转换形状以适应Conv1d

        # 如果没有定义causal_conv1d_fn
        if causal_conv1d_fn is None:
            if cache_params is not None:
                if cache_params.seqlen_offset > 0:
                    # 处理缓存状态
                    conv_state = cache_params.conv_states_dic["pre_conv"][self.layer_idx]
                    conv_state = torch.roll(conv_state, shifts=-1, dims=-1)
                    conv_state[:, :, -1] = hidden_states[:, :, 0]
                    cache_params.conv_states_dic["pre_conv"][self.layer_idx].copy_(conv_state)
                    hidden_states = torch.sum(conv_state * self.conv.weight[:, 0, :], dim=-1)
                    hidden_states += self.conv.bias
                    hidden_states = hidden_states.unsqueeze(-1)
                else:
                    conv_state = nn.functional.pad(
                        hidden_states,
                        (self.config.conv_kernel - hidden_states.shape[-1], 0),
                    )
                    cache_params.conv_states_dic["pre_conv"][self.layer_idx].copy_(conv_state)
                    hidden_states = self.conv(hidden_states)[..., :seq_len]
            else:
                hidden_states = self.conv(hidden_states)[..., :seq_len]
        else:
            conv_weights = self.conv.weight.view(self.conv.weight.size(0), self.conv.weight.size(2))
            if cache_params is not None and cache_params.seqlen_offset > 0:
                hidden_states = causal_conv1d_update(
                    hidden_states.squeeze(-1),
                    cache_params.conv_states_dic["pre_conv"][self.layer_idx],
                    conv_weights,
                    self.conv.bias,
                    None,
                )
                hidden_states = hidden_states.unsqueeze(-1)
            else:
                if cache_params is not None:
                    conv_states = nn.functional.pad(
                        hidden_states,
                        (self.config.conv_kernel - hidden_states.shape[-1], 0),
                    )
                    cache_params.conv_states_dic["pre_conv"][self.layer_idx].copy_(conv_states)
                hidden_states = causal_conv1d_fn(hidden_states, conv_weights, self.conv.bias, activation=None)

        hidden_states = hidden_states.transpose(1, 2)  # 转换形状以适应后续层

        return hidden_states


#########################
### TTT Layer Modules ###
#########################

def scan(f, init, xs, out, checkpoint_group=0):
    """模拟jax.lax.scan函数。

    Args:
        f (function): 扫描函数。
        init (any): 初始化值。
        xs (list or dict): 输入数据。
        out (list): 输出数据。
        checkpoint_group (int, optional): 检查点组的大小。默认值为0。

    Returns:
        any: 最终的carry值。
        list: 输出数据。
    """
    carry = init
    if isinstance(xs, dict):
        num_items = len(next(iter(xs.values())))
    else:
        num_items = len(xs[0])

    def scan_fn(carry, i_start, i_end):
        for i in range(i_start, i_end):
            if isinstance(xs, dict):
                x = {key: tensor[i] for key, tensor in xs.items()}
            else:
                x = [x[i] for x in xs]
            carry, y = f(carry, x)
            out[i] = y
        return carry

    if checkpoint_group > 0:
        ckpt_every_n = num_items // checkpoint_group
        for k in range(0, num_items, ckpt_every_n):
            carry = torch.utils.checkpoint.checkpoint(
                scan_fn, carry, k, min(k + ckpt_every_n, num_items), use_reentrant=False
            )
    else:
        carry = scan_fn(carry, 0, num_items)

    return carry, out


def ln_fwd(x, gamma, beta, eps=1e-6):
    """层归一化的前向传播。

    Args:
        x (torch.Tensor): 输入张量。
        gamma (torch.Tensor): 缩放参数。
        beta (torch.Tensor): 平移参数。
        eps (float, optional): 防止除零的小常数。默认值为1e-6。

    Returns:
        torch.Tensor: 归一化后的张量。
    """
    mu = x.mean(dim=-1, keepdim=True)  # 计算均值
    var = x.var(dim=-1, keepdim=True, unbiased=False)  # 计算方差

    std = torch.sqrt(var + eps)  # 计算标准差
    x_hat = (x - mu) / std  # 标准化

    y = gamma * x_hat + beta  # 缩放和平移

    return y


def ln_fused_l2_bwd(x, l2_target, gamma, beta, eps=1e-6):
    """融合L2损失的层归一化的反向传播。

    Args:
        x (torch.Tensor): 输入张量。
        l2_target (torch.Tensor): L2目标张量。
        gamma (torch.Tensor): 缩放参数。
        beta (torch.Tensor): 平移参数。
        eps (float, optional): 防止除零的小常数。默认值为1e-6。

    Returns:
        torch.Tensor: 梯度张量。
    """
    D = x.shape[-1]

    mu = x.mean(dim=-1, keepdim=True)  # 计算均值
    var = x.var(dim=-1, keepdim=True, unbiased=False)  # 计算方差

    std = torch.sqrt(var + eps)  # 计算标准差
    x_hat = (x - mu) / std  # 标准化

    y = gamma * x_hat + beta  # 缩放和平移

    grad_output = y - l2_target  # 计算输出的梯度
    grad_x_hat = grad_output * gamma  # 计算标准化后的梯度
    z = (
        (1.0 / D)
        * (
            D * grad_x_hat
            - grad_x_hat.sum(dim=-1, keepdim=True)
            - x_hat * (grad_x_hat * x_hat).sum(dim=-1, keepdim=True)
        )
        / std
    )  # 计算最终的梯度

    return z


# 修改自 https://github.com/NVIDIA/Megatron-LM/blob/e33c8f78a35765d5aa37475a144da60e8a2349d1/megatron/core/fusions/fused_bias_gelu.py#L26
def gelu_bwd(x):
    """GELU激活函数的反向传播。

    Args:
        x (torch.Tensor): 输入张量。

    Returns:
        torch.Tensor: 梯度张量。
    """
    tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))  # 计算tanh输出
    ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out)  # 计算梯度
    return ff
  • Conv:

    • 初始化方法定义了RMSNorm层和一维卷积层。
    • 前向传播方法根据是否存在缓存参数及其状态,对输入张量进行归一化、卷积操作。
  • scan 函数:

    • 模拟了 jax.lax.scan 函数的行为,通过逐步应用函数 f 处理输入数据,并支持检查点机制以节省内存。
  • ln_fwd 函数:

    • 实现了层归一化的前向传播,计算均值、方差,并进行标准好的

Class TTTCache

import torch
from torch import nn
from collections import defaultdict
import logging

logger = logging.getLogger(__name__)

class TTTCache:
    """
    TTTCache 是一个用于存储 TTT 层的最后隐藏状态和梯度的数据结构。

    Arguments:
        model (TTTModel): 用于初始化缓存的模型。
        batch_size (int): 批处理大小。

    Attributes:
        seqlen_offset (int): 序列长度偏移量。
        mini_batch_size (int): 微批处理大小。
        params_dict (Dict[str, Dict[int, torch.Tensor]]): 以字典形式保存的参数,*_states, *_grad -> 层索引 -> [批处理大小, ...]
        conv_states_dic (Dict[str, Dict[int, torch.Tensor]]): 以字典形式保存的卷积状态,*_states -> 层索引 -> [批处理大小, ...]
    """

    def __init__(self, model, batch_size: int):
        config = model.config
        self.seqlen_offset = 0  # 初始化序列长度偏移量
        self.mini_batch_size = config.mini_batch_size  # 获取微批处理大小

        # 初始化TTT参数字典
        self.ttt_params_dict = defaultdict(dict)
        if "linear" in config.ttt_layer_type:
            # 如果TTT层类型是线性,则使用W1和b1参数
            self.ttt_param_names = ["W1", "b1"]
        elif "mlp" in config.ttt_layer_type:
            # 如果TTT层类型是MLP,则使用W1, b1, W2, b2参数
            self.ttt_param_names = ["W1", "b1", "W2", "b2"]
        else:
            raise ValueError(f"TTT层类型 {config.ttt_layer_type} 目前不支持")

        # 初始化卷积状态字典
        self.conv_states_dic = defaultdict(dict)
        logger.info(f"Creating cache of size: {batch_size}")
        for layer_idx in range(config.num_hidden_layers):
            for name in self.ttt_param_names:
                weight = getattr(model.layers[layer_idx].seq_modeling_block, name)
                # 将权重扩展至批处理大小
                tiled_weight = torch.tile(weight.unsqueeze(0), (batch_size,) + (1,) * weight.dim()).to(model.device)
                self.ttt_params_dict[f"{name}_states"][layer_idx] = tiled_weight
                # 对于解码,需要存储梯度
                self.ttt_params_dict[f"{name}_grad"][layer_idx] = torch.zeros_like(tiled_weight)

            if config.pre_conv:
                self.conv_states_dic["pre_conv"][layer_idx] = torch.zeros(
                    batch_size,
                    config.hidden_size,
                    config.conv_kernel,
                    device=model.device,
                )
            if config.share_qk:
                self.conv_states_dic["ttt_conv_q"][layer_idx] = torch.zeros(
                    batch_size,
                    config.hidden_size,
                    config.conv_kernel,
                    device=model.device,
                )
                self.conv_states_dic["ttt_conv_k"][layer_idx] = torch.zeros(
                    batch_size,
                    config.hidden_size,
                    config.conv_kernel,
                    device=model.device,
                )

    def update(self, py_tree, layer_idx, seq_len):
        """更新缓存中的参数。

        Args:
            py_tree (dict): 参数树。
            layer_idx (int): 层索引。
            seq_len (int): 序列长度。
        """
        if seq_len % self.mini_batch_size == 0:
            # 复制最后一个微批处理的状态,清空梯度
            for name in self.ttt_param_names:
                self.ttt_params_dict[f"{name}_states"][layer_idx].copy_(py_tree[f"{name}_states"])
                self.ttt_params_dict[f"{name}_grad"][layer_idx].zero_()
        elif seq_len < self.mini_batch_size:
            if seq_len != 1 and self.seqlen_offset > 0 and self.seqlen_offset % self.mini_batch_size != 0:
                raise ValueError("不支持分段更新")
            if (seq_len + self.seqlen_offset) % self.mini_batch_size == 0:
                # 复制最后一个微批处理的状态,清空梯度
                for name in self.ttt_param_names:
                    self.ttt_params_dict[f"{name}_states"][layer_idx].copy_(py_tree[f"{name}_states"])
                    self.ttt_params_dict[f"{name}_grad"][layer_idx].zero_()
            else:
                # 复制梯度以供下次更新使用
                for name in self.ttt_param_names:
                    self.ttt_params_dict[f"{name}_grad"][layer_idx].copy_(py_tree[f"{name}_grad"])
        else:
            raise ValueError(f"序列长度 {seq_len} 是不支持的部分更新")

    def ttt_params_to_dict(self, layer_idx):
        """将特定层的TTT参数转化为字典格式。

        Args:
            layer_idx (int): 层索引。

        Returns:
            dict: 包含特定层TTT参数的字典。
        """
        return {name: self.ttt_params_dict[name][layer_idx] for name in self.ttt_params_dict}
  • __init__ 方法:

    • 初始化了各种参数及其默认值,包括序列长度偏移量、微批处理大小等。
    • 根据配置中的TTT层类型(线性或MLP),初始化相应的参数字典。
    • 初始化卷积状态字典,并将权重扩展至批处理大小。
  • update 方法:

    • 更新缓存中的参数,如果序列长度是微批处理大小的整数倍,则复制状态并清空梯度。
    • 如果序列长度小于微批处理大小,检查偏移量是否支持分段更新,并进行相应操作。
    • 如果序列长度不满足以上条件,抛出错误。
  • ttt_params_to_dict 方法:

    • 将特定层的TTT参数转化为字典格式,便于后续处理。

  • 4
    点赞
  • 11
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值