最新大模型架构TTT模型代码解析(二)

Class TTTBase

TTTBase 类通过方法和属性构建了一个灵活的模型层,可以根据不同配置调整其行为。这个类是构建更复杂模型的基础,如 TTTLinear 和 TTTMLP,它们继承自 TTTBase 并实现具体的 TTT 逻辑



class TTTBase(nn.Module):
    def __init__(self, config: TTTConfig, layer_idx: Optional[int] = None):
        """初始化TTTBase层。

        Args:
            config (TTTConfig): 模型配置对象。
            layer_idx (Optional[int]): 当前层的索引。默认值为None。
        """
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        if layer_idx is None:
            logger.warning_once(
                f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
                "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
                "when creating this class."
            )

        self.width = config.hidden_size  # 设置隐藏层大小
        self.hidden_size = config.hidden_size  # 设置隐藏层大小
        self.num_heads = config.num_attention_heads  # 设置注意力头的数量
        self.head_dim = self.width // self.num_heads  # 计算每个头的维度
        self.mini_batch_size = config.mini_batch_size  # 设置微批处理大小

        # token_idx 是一个缩放因子,用于在公式4中进行求和
        token_idx = 1.0 / torch.arange(1, self.mini_batch_size + 1)
        self.register_buffer("token_idx", token_idx, persistent=False)
        # 使缩放因子可学习
        self.learnable_token_idx = nn.Parameter(torch.zeros((self.mini_batch_size,)))

        self.share_qk = config.share_qk  # 是否共享Q/K投影
        self.conv_kernel = config.conv_kernel  # 卷积核大小
        self._init_qkvo_proj()  # 初始化Q/K/V输出投影
        self._init_rope()  # 初始化旋转位置嵌入
        # 可学习的eta参数
        self._init_ttt_lr_gate()
        self._init_ttt_ln()

        # 是否使用门控机制
        self.use_gate = config.use_gate
        if self.use_gate:
            self.g_proj = nn.Linear(self.width, self.width, bias=False)

        self.post_norm = nn.LayerNorm(self.width, eps=1e-6)

    def _init_qkvo_proj(self):
        """初始化Q/K/V投影"""
        self.q_proj = nn.Linear(self.width, self.num_heads * self.head_dim, bias=False)
        if not self.share_qk:
            self.k_proj = nn.Linear(self.width, self.num_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(self.width, self.num_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.width, self.num_heads * self.head_dim, bias=False)

        if self.share_qk:
            self.conv_q = nn.Conv1d(
                self.hidden_size,
                self.hidden_size,
                bias=True,
                kernel_size=self.conv_kernel,
                groups=self.hidden_size,
                padding=self.conv_kernel - 1,
            )
            self.conv_k = nn.Conv1d(
                self.hidden_size,
                self.hidden_size,
                bias=True,
                kernel_size=self.conv_kernel,
                groups=self.hidden_size,
                padding=self.conv_kernel - 1,
            )

    def _init_rope(self):
        """初始化旋转位置嵌入"""
        self.rope_theta = self.config.rope_theta
        self.rotary_emb = RotaryEmbedding(
            self.head_dim,
            max_position_embeddings=self.mini_batch_size,
            base=self.rope_theta,
        )

    def _init_ttt_lr_gate(self):
        """初始化可学习的η参数"""
        linear_weight_data = nn.Linear(self.width, 1, bias=True).weight.data
        self.learnable_ttt_lr_weight = nn.Parameter(
            torch.stack(
                [torch.normal(0, 0.02, size=linear_weight_data.shape) for _ in range(self.num_heads)],
                dim=0,
            )
        )
        linear_bias_data = nn.Linear(self.width, 1, bias=True).bias.data
        self.learnable_ttt_lr_bias = nn.Parameter(
            torch.stack(
                [torch.zeros_like(linear_bias_data) for _ in range(self.num_heads)],
                dim=0,
            )
        )

    def _init_ttt_ln(self):
        """初始化可学习的层归一化参数"""
        ln_weight_data = nn.LayerNorm(self.head_dim).weight.data
        self.ttt_norm_weight = nn.Parameter(torch.tile(ln_weight_data.unsqueeze(0), (self.num_heads, 1)))
        ln_bias_data = nn.LayerNorm(self.head_dim).bias.data
        self.ttt_norm_bias = nn.Parameter(torch.tile(ln_bias_data.unsqueeze(0), (self.num_heads, 1)))

    def get_qkv_projections(self, hidden_states, cache_params: Optional[TTTCache] = None):
        """获取Q/K/V投影

        Args:
            hidden_states (torch.Tensor): 输入隐藏状态。
            cache_params (Optional[TTTCache]): 缓存参数。默认值为None。

        Returns:
            Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Q/K/V投影张量。
        """
        if self.share_qk:
            xq, XV = self.q_proj(hidden_states), self.v_proj(hidden_states)
            seq_len = xq.shape[1]
            xq = xq.transpose(1, 2)
            if causal_conv1d_fn is None:
                if cache_params is not None:
                    if cache_params.seqlen_offset > 0:
                        conv_q_state = cache_params.conv_states_dic["ttt_conv_q"][self.layer_idx]
                        conv_q_state = torch.roll(conv_q_state, shifts=-1, dims=-1)
                        conv_q_state[:, :, -1] = xq[:, :, 0]
                        cache_params.conv_states_dic["ttt_conv_q"][self.layer_idx].copy_(conv_q_state)
                        XQ = torch.sum(conv_q_state * self.conv_q.weight[:, 0, :], dim=-1)
                        XQ += self.conv_q.bias
                        XQ = XQ.unsqueeze(-1)

                        conv_k_state = cache_params.conv_states_dic["ttt_conv_k"][self.layer_idx]
                        conv_k_state = torch.roll(conv_k_state, shifts=-1, dims=-1)
                        conv_k_state[:, :, -1] = xq[:, :, 0]
                        cache_params.conv_states_dic["ttt_conv_k"][self.layer_idx].copy_(conv_k_state)
                        XK = torch.sum(conv_k_state * self.conv_k.weight[:, 0, :], dim=-1)
                        XK += self.conv_k.bias
                        XK = XK.unsqueeze(-1)
                    else:
                        conv_q_state = nn.functional.pad(xq, (self.config.conv_kernel - xq.shape[-1], 0))
                        cache_params.conv_states_dic["ttt_conv_q"][self.layer_idx].copy_(conv_q_state)
                        XQ = self.conv_q(xq)[..., :seq_len]
                        conv_k_state = nn.functional.pad(xq, (self.config.conv_kernel - xq.shape[-1], 0))
                        cache_params.conv_states_dic["ttt_conv_k"][self.layer_idx].copy_(conv_k_state)
                        XK = self.conv_k(xq)[..., :seq_len]
                else:
                    XQ = self.conv_q(xq)[..., :seq_len]
                    XK = self.conv_k(xq)[..., :seq_len]
            else:
                conv_q_weights = self.conv_q.weight.view(self.conv_q.weight.size(0), self.conv_q.weight.size(2))
                conv_k_weights = self.conv_k.weight.view(self.conv_k.weight.size(0), self.conv_k.weight.size(2))
                if cache_params is not None and cache_params.seqlen_offset > 0:
                    XQ = causal_conv1d_update(
                        xq.squeeze(-1),
                        cache_params.conv_states_dic["ttt_conv_q"][self.layer_idx],
                        conv_q_weights,
                        self.conv_q.bias,
                        None,
                    )
                    XQ = XQ.unsqueeze(-1)
                    XK = causal_conv1d_update(
                        xq.squeeze(-1),
                        cache_params.conv_states_dic["ttt_conv_k"][self.layer_idx],
                        conv_k_weights,
                        self.conv_k.bias,
                        None,
                    )
                    XK = XK.unsqueeze(-1)
                else:
                    if cache_params is not None:
                        conv_q_states = nn.functional.pad(xq, (self.config.conv_kernel - xq.shape[-1], 0))
                        cache_params.conv_states_dic["ttt_conv_q"][self.layer_idx].copy_(conv_q_states)
                        conv_k_states = nn.functional.pad(xq, (self.config.conv_kernel - xq.shape[-1], 0))
                        cache_params.conv_states_dic["ttt_conv_k"][self.layer_idx].copy_(conv_k_states)
                    XQ = causal_conv1d_fn(xq, conv_q_weights, self.conv_q.bias, activation=None)
                    XK = causal_conv1d_fn(xq, conv_k_weights, self.conv_k.bias, activation=None)

            XQ = XQ.transpose(1, 2)
            XK = XK.transpose(1, 2)
        else:
            XQ, XK, XV = (
                self.q_proj(hidden_states),
                self.k_proj(hidden_states),
                self.v_proj(hidden_states),
            )
        return XQ, XK, XV

    def _split_heads(self, hidden_states):
        """将隐藏状态分成多个头"""
        return hidden_states.reshape(hidden_states.shape[:2] + (self.num_heads, self.head_dim))

    def get_eta(self, X, mini_batch_step_offset, mini_batch_size):
        """计算可学习的η参数

        Args:
            X (torch.Tensor): 输入张量。
            mini_batch_step_offset (int): 微批处理步长偏移量。
            mini_batch_size (int): 微批处理大小。

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: token_eta和ttt_lr_eta张量。
        """
        ttt_lr = torch.einsum("bnkc,hdc->bhnkd", X, self.learnable_ttt_lr_weight) + self.learnable_ttt_lr_bias.reshape(
            1, -1, 1, 1, 1
        )
        ttt_lr = F.sigmoid(ttt_lr)

        ttt_lr = ttt_lr.permute(0, 1, 2, 4, 3)
        ttt_lr_eta = self.config.ttt_base_lr * ttt_lr / self.head_dim

        token_idx = self.token_idx + self.learnable_token_idx
        token_idx = token_idx[mini_batch_step_offset : mini_batch_step_offset + mini_batch_size]

        token_idx = torch.clamp_min(token_idx, 0.0)

        token_eta = torch.broadcast_to(
            token_idx.reshape(1, 1, 1, mini_batch_size, 1),
            (X.shape[0], self.num_heads, X.shape[1], mini_batch_size, 1),
        )

        return token_eta, ttt_lr_eta

    def apply_gate(self, hidden_states, ttt_output):
        """应用门控机制

        Args:
            hidden_states (torch.Tensor): 输入隐藏状态。
            ttt_output (torch.Tensor): TTT输出。

        Returns:
            torch.Tensor: 门控后的输出。
        """
        y = self.g_proj(hidden_states)
        y = F.gelu(y, approximate="tanh")
        output = y * ttt_output
        return output

    def get_ttt_inputs(self, inputs, mini_batch_size, cache_params):
        """获取TTT输入

        Args:
            inputs (dict): 输入数据。
            mini_batch_size (int): 微批处理大小。
            cache_params (Optional[TTTCache]): 缓存参数。默认值为None。

        Returns:
            dict: TTT输入数据。
        """
        XQ = inputs["XQ"]
        XK = inputs["XK"]
        XV = inputs["XV"]
        X = inputs["X"]
        B, L, C = X.shape
        num_mini_batch = L // mini_batch_size
        X = X.reshape(B, num_mini_batch, mini_batch_size, self.width)

        XQ = XQ.reshape(B, self.num_heads, L // mini_batch_size, mini_batch_size, self.head_dim)
        XK = XK.reshape(B, self.num_heads, L // mini_batch_size, mini_batch_size, self.head_dim)
        XV = XV.reshape(B, self.num_heads, L // mini_batch_size, mini_batch_size, self.head_dim)

        if cache_params is not None:
            mini_batch_step_offset = cache_params.seqlen_offset % self.mini_batch_size
        else:
            mini_batch_step_offset = 0
        token_eta, ttt_lr_eta = self.get_eta(X, mini_batch_step_offset, mini_batch_size)
        eta = token_eta * ttt_lr_eta
        inputs = {
            "XQ": XQ,
            "XK": XK,
            "XV": XV,
            "eta": eta,
            "token_eta": token_eta,
            "ttt_lr_eta": ttt_lr_eta,
        }
        return inputs

    def ttt(
        self,
        inputs,
        mini_batch_size,
        last_mini_batch_params_dict,
        cache_params: Optional[TTTCache] = None,
    ):
        """TTT方法,必须在TTTBase的子类中实现"""
        raise NotImplementedError("ttt方法必须在TTTBase子类中实现。")

   def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        cache_params: Optional[TTTCache] = None,
    ):
        """前向传播函数

        Args:
            hidden_states (torch.Tensor): 输入隐藏状态张量。
            attention_mask (Optional[torch.Tensor]): 注意力掩码。默认值为None。
            position_ids (Optional[torch.LongTensor]): 位置ID。默认值为None。
            cache_params (Optional[TTTCache]): 缓存参数。默认值为None。

        Returns:
            torch.Tensor: 输出隐藏状态张量。
        """
        # 获取批处理大小和序列长度
        B, L = hidden_states.shape[:2]
        # 计算序列长度除以微批处理大小的余数
        reminder_len = L % self.mini_batch_size
        # 计算序列可以被微批处理大小整除的部分
        num_mini_batch = L // self.mini_batch_size
        last_mini_batch_params_dict = None

        # 获取Q/K/V投影
        XQ, XK, XV = self.get_qkv_projections(hidden_states, cache_params=cache_params)

        # 将Q/K/V重新形状化并进行转置
        XQ = XQ.reshape(B, L, self.num_heads, self.head_dim).transpose(1, 2)
        XK = XK.reshape(B, L, self.num_heads, self.head_dim).transpose(1, 2)
        XV = XV.reshape(B, L, self.num_heads, self.head_dim).transpose(1, 2)

        # 获取旋转嵌入的余弦和正弦
        cos, sin = self.rotary_emb(XV, position_ids % self.mini_batch_size)

        # 应用旋转位置嵌入
        XQ, XK = permute_qk(XQ, XK)
        XQ, XK = apply_rotary_pos_emb(XQ, XK, cos, sin)
        XQ, XK = undo_permute_qk(XQ, XK)

        # 初始化输出隐藏状态列表
        output_hidden_states = []

        # 如果序列长度可以被微批处理大小整除
        if num_mini_batch > 0:
            # 构建输入字典
            inputs = {
                "XQ": XQ[:, :, : num_mini_batch * self.mini_batch_size],
                "XK": XK[:, :, : num_mini_batch * self.mini_batch_size],
                "XV": XV[:, :, : num_mini_batch * self.mini_batch_size],
                "X": hidden_states[:, : num_mini_batch * self.mini_batch_size],
            }
            # 获取TTT输入并调用TTT方法
            output_mod, last_mini_batch_params_dict = self.ttt(
                self.get_ttt_inputs(inputs, self.mini_batch_size, cache_params),
                mini_batch_size=self.mini_batch_size,
                last_mini_batch_params_dict=last_mini_batch_params_dict,
                cache_params=cache_params,
            )
            # 将输出添加到输出隐藏状态列表中
            output_hidden_states.append(output_mod)

        # 如果序列长度有余数部分
        if reminder_len > 0:
            # 构建输入字典
            inputs = {
                "XQ": XQ[:, :, -reminder_len:],
                "XK": XK[:, :, -reminder_len:],
                "XV": XV[:, :, -reminder_len:],
                "X": hidden_states[:, -reminder_len:],
            }
            # 获取TTT输入并调用TTT方法
            output_reminder, _ = self.ttt(
                self.get_ttt_inputs(inputs, reminder_len, cache_params),
                mini_batch_size=reminder_len,
                last_mini_batch_params_dict=last_mini_batch_params_dict,
                cache_params=cache_params,
            )
            # 将输出添加到输出隐藏状态列表中
            output_hidden_states.append(output_reminder)

        # 将所有输出隐藏状态拼接在一起
        output_hidden_states = torch.cat(output_hidden_states, dim=1)
        # 应用层归一化
        output_hidden_states = self.post_norm(output_hidden_states)
        # 如果使用门控机制
        if self.use_gate:
            output_hidden_states = self.apply_gate(hidden_states, output_hidden_states)
        # 应用输出投影
        output_hidden_states = self.o_proj(output_hidden_states)

        return output_hidden_states

TTTLinear 和 TTTMLP

TTTLinear 和 TTTMLP,它们继承自 TTTBase 并实现具体的 TTT 逻辑

class TTTLinear(TTTBase):
    def __init__(self, config: TTTConfig, layer_idx: Optional[int] = None):
        """初始化TTT-Linear层

        Args:
            config (TTTConfig): 模型配置对象。
            layer_idx (Optional[int]): 当前层的索引。默认值为None。
        """
        super().__init__(config, layer_idx)
        # 初始化TTT-Linear模型参数
        self.W1 = nn.Parameter(torch.normal(0, 0.02, size=(self.num_heads, self.head_dim, self.head_dim)))
        self.b1 = nn.Parameter(torch.zeros(self.num_heads, 1, self.head_dim))

    def ttt(
        self,
        inputs,
        mini_batch_size,
        last_mini_batch_params_dict,
        cache_params: Optional[TTTCache] = None,
    ):
        """实现TTT方法

        Args:
            inputs (dict): 输入数据。
            mini_batch_size (int): 微批处理大小。
            last_mini_batch_params_dict (dict): 上一个微批处理的参数字典。
            cache_params (Optional[TTTCache]): 缓存参数。默认值为None。

        Returns:
            Tuple[torch.Tensor, dict]: 输出张量和更新后的参数字典。
        """
        if mini_batch_size is None:
            mini_batch_size = self.mini_batch_size

        # 如果缓存参数存在但上一个微批处理的参数字典不存在,则获取缓存参数的字典
        if last_mini_batch_params_dict is None and cache_params is not None:
            last_mini_batch_params_dict = cache_params.ttt_params_to_dict(self.layer_idx)

        # 获取输入张量的形状参数
        B = inputs["XV"].shape[0]  # 批处理大小
        num_mini_batch = inputs["XV"].shape[2]  # 微批处理数量
        L = inputs["XV"].shape[2] * inputs["XV"].shape[3]  # 总序列长度
        device = inputs["XV"].device  # 设备
        dtype = inputs["XV"].dtype  # 数据类型

        # 判断是否使用对偶形式进行计算
        use_dual_form = cache_params is None or mini_batch_size % self.mini_batch_size == 0

        def compute_mini_batch(params_dict, inputs):
            """计算微批处理

            Args:
                params_dict (dict): 参数字典。
                inputs (dict): 输入数据。

            Returns:
                Tuple[dict, torch.Tensor]: 更新后的参数字典和输出张量。
            """
            # 获取初始参数
            W1_init = params_dict["W1_states"]
            b1_init = params_dict["b1_states"]

            # 获取微批处理输入
            XQ_mini_batch = inputs["XQ"]
            XV_mini_batch = inputs["XV"]
            XK_mini_batch = inputs["XK"]
            eta_mini_batch = inputs["eta"]
            token_eta_mini_batch = inputs["token_eta"]
            ttt_lr_eta_mini_batch = inputs["ttt_lr_eta"]

            # 计算线性变换
            X1 = XK_mini_batch
            Z1 = X1 @ W1_init + b1_init
            reconstruction_target = XV_mini_batch - XK_mini_batch

            # 获取层归一化参数
            ln_weight = self.ttt_norm_weight.reshape(self.num_heads, 1, self.head_dim)
            ln_bias = self.ttt_norm_bias.reshape(self.num_heads, 1, self.head_dim)

            # 计算梯度
            grad_l_wrt_Z1 = ln_fused_l2_bwd(Z1, reconstruction_target, ln_weight, ln_bias)

            if use_dual_form:
                # 对偶形式计算
                Attn1 = torch.tril(XQ_mini_batch @ X1.transpose(-2, -1))
                b1_bar = b1_init - torch.tril(eta_mini_batch) @ grad_l_wrt_Z1
                Z1_bar = XQ_mini_batch @ W1_init - (eta_mini_batch * Attn1) @ grad_l_wrt_Z1 + b1_bar

                last_eta_mini_batch = eta_mini_batch[:, :, -1, :, None]
                W1_last = W1_init - (last_eta_mini_batch * X1).transpose(-1, -2) @ grad_l_wrt_Z1
                b1_last = b1_init - torch.sum(last_eta_mini_batch * grad_l_wrt_Z1, dim=-2, keepdim=True)
                grad_W1_last = torch.zeros_like(W1_last)
                grad_b1_last = torch.zeros_like(b1_last)
            else:
                # 原始形式计算
                ttt_lr_eta_mini_batch = torch.broadcast_to(
                    ttt_lr_eta_mini_batch,
                    (
                        *ttt_lr_eta_mini_batch.shape[:2],
                        mini_batch_size,
                        mini_batch_size,
                    ),
                )

                grad_W1 = torch.einsum("bhki,bhkj->bhkij", X1, grad_l_wrt_Z1)
                grad_W1 = torch.einsum("bhnk,bhkij->bhnij", torch.tril(ttt_lr_eta_mini_batch), grad_W1)
                grad_W1 = grad_W1 + params_dict["W1_grad"].unsqueeze(2)
                grad_b1 = torch.einsum("bhnk,bhki->bhni", torch.tril(ttt_lr_eta_mini_batch), grad_l_wrt_Z1)
                grad_b1 = grad_b1 + params_dict["b1_grad"]

                W1_bar = W1_init.unsqueeze(2) - grad_W1 * token_eta_mini_batch.unsqueeze(-1)
                b1_bar = b1_init - grad_b1 * token_eta_mini_batch

                Z1_bar = (XQ_mini_batch.unsqueeze(3) @ W1_bar).squeeze(3) + b1_bar

                W1_last = W1_bar[:, :, -1]
                b1_last = b1_bar[:, :, -1:]
                grad_W1_last = grad_W1[:, :, -1]
                grad_b1_last = grad_b1[:, :, -1:]

            # 应用层归一化
            Z1_bar = ln_fwd(Z1_bar, ln_weight, ln_bias)

            # 更新Q投影
            XQW_mini_batch = XQ_mini_batch + Z1_bar

            last_param_dict = {
                "W1_states": W1_last,
                "b1_states": b1_last,
                "W1_grad": grad_W1_last,
                "b1_grad": grad_b1_last,
            }
            return last_param_dict, XQW_mini_batch

        # 初始化参数字典
        if last_mini_batch_params_dict is not None:
            init_params_dict = last_mini_batch_params_dict
        else:
            init_params_dict = {
                "W1_states": torch.tile(self.W1.unsqueeze(0), dims=(B, 1, 1, 1)),
                "b1_states": torch.tile(self.b1.unsqueeze(0), dims=(B, 1, 1, 1)),
            }
            init_params_dict.update(W1_grad=torch.zeros_like(init_params_dict["W1_states"]))
            init_params_dict.update(b1_grad=torch.zeros_like(init_params_dict["b1_states"]))

        # 转置输入数据
        inputs = tree_map(lambda x: x.permute(2, 0, 1, 3, 4), inputs)

        # 分配输出张量
        XQW_batch = torch.empty(
            (num_mini_batch, B, self.num_heads, mini_batch_size, self.head_dim),
            device=device,
            dtype=dtype,
        )

        # 扫描微批处理
        batch_params_dict, XQW_batch = scan(
            compute_mini_batch,
            init_params_dict,
            inputs,
            XQW_batch,
            self.config.scan_checkpoint_group_size if self.training else 0,
        )

        # 更新缓存参数
        if cache_params is not None:
            cache_params.update(batch_params_dict, self.layer_idx, L)

        # 转置并重新形状化输出
        XQW_batch = XQW_batch.permute(1, 0, 3, 2, 4)
        XQW_batch = XQW_batch.reshape(B, L, self.width)
        return XQW_batch, batch_params_dict
  1. __init__ 方法:

    • 初始化了TTT-Linear层,继承自 TTTBase
    • 初始化了线性变换的权重 W1 和偏置 b1
  2. ttt 方法:

    • 实现了TTT方法,用于处理微批处理的输入数据。
    • 判断是否使用对偶形式进行计算。
  3. compute_mini_batch 函数:

    • 计算微批处理的线性变换和梯度。
    • 根据是否使用对偶形式,分别进行计算。
    • 更新 W1 和 b1 的状态和梯度。
  4. 初始化参数字典:

    • 如果上一个微批处理的参数字典存在,则使用该字典进行初始化。
    • 否则,使用模型的初始参数进行初始化。
  5. 转置输入数据:

    • 将输入数据进行转置,以适应后续的计算流程。
  6. 分配输出张量:

    • 分配一个空的输出张量 XQW_batch
  7. 扫描微批处理:

    • 使用 scan 函数对微批处理进行扫描计算。
    • 更新 batch_params_dict 和 XQW_batch
  8. 更新缓存参数:

    • 如果缓存参数存在,则更新缓存参数。
  9. 转置并重新形状化输出:

    • 对输出进行转置并重新形状化。
class TTTMLP(TTTBase):
    def __init__(self, config: TTTConfig, layer_idx: Optional[int] = None):
        """初始化TTT-MLP层

        Args:
            config (TTTConfig): 模型配置对象。
            layer_idx (Optional[int]): 当前层的索引。默认值为None。
        """
        super().__init__(config, layer_idx)
        # 初始化TTT-MLP模型参数
        self.W1 = nn.Parameter(torch.normal(0, 0.02, size=(self.num_heads, self.head_dim, 4 * self.head_dim)))
        self.b1 = nn.Parameter(torch.zeros(self.num_heads, 1, 4 * self.head_dim))
        self.W2 = nn.Parameter(torch.normal(0, 0.02, size=(self.num_heads, 4 * self.head_dim, self.head_dim)))
        self.b2 = nn.Parameter(torch.zeros(self.num_heads, 1, self.head_dim))

    def ttt(
        self,
        inputs,
        mini_batch_size,
        last_mini_batch_params_dict,
        cache_params: Optional[TTTCache] = None,
    ):
        """实现TTT方法

        Args:
            inputs (dict): 输入数据。
            mini_batch_size (int): 微批处理大小。
            last_mini_batch_params_dict (dict): 上一个微批处理的参数字典。
            cache_params (Optional[TTTCache]): 缓存参数。默认值为None。

        Returns:
            Tuple[torch.Tensor, dict]: 输出张量和更新后的参数字典。
        """
        if mini_batch_size is None:
            mini_batch_size = self.mini_batch_size

        # 如果缓存参数存在但上一个微批处理的参数字典不存在,则获取缓存参数的字典
        if last_mini_batch_params_dict is None and cache_params is not None:
            last_mini_batch_params_dict = cache_params.ttt_params_to_dict(self.layer_idx)

        # 获取输入张量的形状参数
        B = inputs["XV"].shape[0]  # 批处理大小
        num_mini_batch = inputs["XV"].shape[2]  # 微批处理数量
        L = inputs["XV"].shape[2] * inputs["XV"].shape[3]  # 总序列长度
        device = inputs["XV"].device  # 设备
        dtype = inputs["XV"].dtype  # 数据类型

        # 判断是否使用对偶形式进行计算
        use_dual_form = cache_params is None or mini_batch_size % self.mini_batch_size == 0

        def compute_mini_batch(params_dict, inputs):
            """计算微批处理

            Args:
                params_dict (dict): 参数字典。
                inputs (dict): 输入数据。

            Returns:
                Tuple[dict, torch.Tensor]: 更新后的参数字典和输出张量。
            """
            # 获取初始参数
            W1_init = params_dict["W1_states"]
            b1_init = params_dict["b1_states"]
            W2_init = params_dict["W2_states"]
            b2_init = params_dict["b2_states"]

            # 获取微批处理输入
            XQ_mini_batch = inputs["XQ"]
            XV_mini_batch = inputs["XV"]
            XK_mini_batch = inputs["XK"]
            eta_mini_batch = inputs["eta"]
            token_eta_mini_batch = inputs["token_eta"]
            ttt_lr_eta_mini_batch = inputs["ttt_lr_eta"]

            # 计算线性变换和激活函数
            X1 = XK_mini_batch
            Z1 = X1 @ W1_init + b1_init
            X2 = F.gelu(Z1, approximate="tanh")
            Z2 = X2 @ W2_init + b2_init
            reconstruction_target = XV_mini_batch - XK_mini_batch

            # 获取层归一化参数
            ln_weight = self.ttt_norm_weight.reshape(self.num_heads, 1, self.head_dim)
            ln_bias = self.ttt_norm_bias.reshape(self.num_heads, 1, self.head_dim)

            # 计算梯度
            grad_l_wrt_Z2 = ln_fused_l2_bwd(Z2, reconstruction_target, ln_weight, ln_bias)
            grad_l_wrt_Z1 = grad_l_wrt_Z2 @ W2_init.transpose(-2, -1) * gelu_bwd(Z1)

            if use_dual_form:
                # 对偶形式计算
                Attn1 = torch.tril(XQ_mini_batch @ X1.transpose(-2, -1))
                b1_bar = b1_init - torch.tril(eta_mini_batch) @ grad_l_wrt_Z1
                Z1_bar = XQ_mini_batch @ W1_init - (eta_mini_batch * Attn1) @ grad_l_wrt_Z1 + b1_bar
                X2_bar = F.gelu(Z1_bar, approximate="tanh")

                Attn2 = torch.tril(X2_bar @ X2.transpose(-2, -1))
                b2_bar = b2_init - torch.tril(eta_mini_batch) @ grad_l_wrt_Z2
                Z2_bar = X2_bar @ W2_init - (eta_mini_batch * Attn2) @ grad_l_wrt_Z2 + b2_bar

                last_eta_mini_batch = eta_mini_batch[:, :, -1, :, None]
                W1_last = W1_init - (last_eta_mini_batch * X1).transpose(-1, -2) @ grad_l_wrt_Z1
                b1_last = b1_init - torch.sum(last_eta_mini_batch * grad_l_wrt_Z1, dim=-2, keepdim=True)
                W2_last = W2_init - (last_eta_mini_batch * X2).transpose(-1, -2) @ grad_l_wrt_Z2
                b2_last = b2_init - torch.sum(last_eta_mini_batch * grad_l_wrt_Z2, dim=-2, keepdim=True)
                grad_W1_last = torch.zeros_like(W1_last)
                grad_b1_last = torch.zeros_like(b1_last)
                grad_W2_last = torch.zeros_like(W2_last)
                grad_b2_last = torch.zeros_like(b2_last)

            else:
                # 原始形式计算
                ttt_lr_eta_mini_batch = torch.broadcast_to(
                    ttt_lr_eta_mini_batch,
                    (
                        *ttt_lr_eta_mini_batch.shape[:2],
                        mini_batch_size,
                        mini_batch_size,
                    ),
                )

                grad_W2 = torch.einsum("bhki,bhkj->bhkij", X2, grad_l_wrt_Z2)
                grad_W2 = torch.einsum("bhnk,bhkij->bhnij", torch.tril(ttt_lr_eta_mini_batch), grad_W2)
                grad_W2 = grad_W2 + params_dict["W2_grad"].unsqueeze(2)
                grad_b2 = torch.einsum("bhnk,bhki->bhni", torch.tril(ttt_lr_eta_mini_batch), grad_l_wrt_Z2)
                grad_b2 = grad_b2 + params_dict["b2_grad"]

                grad_W1 = torch.einsum("bhki,bhkj->bhkij", X1, grad_l_wrt_Z1)
                grad_W1 = torch.einsum("bhnk,bhkij->bhnij", torch.tril(ttt_lr_eta_mini_batch), grad_W1)
                grad_W1 = grad_W1 + params_dict["W1_grad"].unsqueeze(2)
                grad_b1 = torch.einsum("bhnk,bhki->bhni", torch.tril(ttt_lr_eta_mini_batch), grad_l_wrt_Z1)
                grad_b1 = grad_b1 + params_dict["b1_grad"]

                W1_bar = W1_init.unsqueeze(2) - grad_W1 * token_eta_mini_batch.unsqueeze(-1)
                b1_bar = b1_init - grad_b1 * token_eta_mini_batch
                W2_bar = W2_init.unsqueeze(2) - grad_W2 * token_eta_mini_batch.unsqueeze(-1)
                b2_bar = b2_init - grad_b2 * token_eta_mini_batch

                Z1_bar = (XQ_mini_batch.unsqueeze(3) @ W1_bar).squeeze(3) + b1_bar
                X2_bar = F.gelu(Z1_bar, approximate="tanh")
                Z2_bar = (X2_bar.unsqueeze(3) @ W2_bar).squeeze(3) + b2_bar

                W1_last = W1_bar[:, :, -1]
                b1_last = b1_bar[:, :, -1:]
                W2_last = W2_bar[:, :, -1]
                b2_last = b2_bar[:, :, -1:]
                grad_W1_last = grad_W1[:, :, -1]
                grad_b1_last = grad_b1[:, :, -1:]
                grad_W2_last = grad_W2[:, :, -1]
                grad_b2_last = grad_b2[:, :, -1:]

            # 应用层归一化
            Z2_bar = ln_fwd(Z2_bar, ln_weight, ln_bias)

            # 更新Q投影
            XQW_mini_batch = XQ_mini_batch + Z2_bar

            last_param_dict = {
                "W1_states": W1_last,
                "b1_states": b1_last,
                "W2_states": W2_last,
                "b2_states": b2_last,
                "W1_grad": grad_W1_last,
                "b1_grad": grad_b1_last,
                "W2_grad": grad_W2_last,
                "b2_grad": grad_b2_last,
            }
            return last_param_dict, XQW_mini_batch

        # 初始化参数字典
        if last_mini_batch_params_dict is not None:
            init_params_dict = last_mini_batch_params_dict
        else:
            init_params_dict = {
                "W1_states": torch.tile(self.W1.unsqueeze(0), dims=(B, 1, 1, 1)),
                "b1_states": torch.tile(self.b1.unsqueeze(0), dims=(B, 1, 1, 1)),
                "W2_states": torch.tile(self.W2.unsqueeze(0), dims=(B, 1, 1, 1)),
                "b2_states": torch.tile(self.b2.unsqueeze(0), dims=(B, 1, 1, 1)),
            }
            init_params_dict.update(W1_grad=torch.zeros_like(init_params_dict["W1_states"]))
            init_params_dict.update(b1_grad=torch.zeros_like(init_params_dict["b1_states"]))
            init_params_dict.update(W2_grad=torch.zeros_like(init_params_dict["W2_states"]))
            init_params_dict.update(b2_grad=torch.zeros_like(init_params_dict["b2_states"]))

        # 转置输入数据
        inputs = tree_map(lambda x: x.permute(2, 0, 1, 3, 4), inputs)

        # 分配输出张量
        XQW_batch = torch.empty(
            (num_mini_batch, B, self.num_heads, mini_batch_size, self.head_dim),
            device=device,
            dtype=dtype,
        )

        # 扫描微批处理
        batch_params_dict, XQW_batch = scan(
            compute_mini_batch,
            init_params_dict,
            inputs,
            XQW_batch,
            self.config.scan_checkpoint_group_size if self.training else 0,
        )

        # 更新缓存参数
        if cache_params is not None:
            cache_params.update(batch_params_dict, self.layer_idx, L)

        # 转置并重新形状化输出
        XQW_batch = XQW_batch.permute(1, 0, 3, 2, 4)
        XQW_batch = XQW_batch.reshape(B, L, self.width)
        return XQW_batch, batch_params_dict
  1. __init__ 方法:

    • 初始化了TTT-MLP层,继承自 TTTBase
    • 初始化了两层线性变换的权重 W1 和 W2 以及对应的偏置 b1 和 b2
  2. ttt 方法:

    • 实现了TTT方法,用于处理微批处理的输入数据。
    • 判断是否使用对偶形式进行计算。
  3. compute_mini_batch 函数:

    • 计算微批处理的线性变换、激活函数和梯度。
    • 根据是否使用对偶形式,分别进行计算。
    • 更新 W1W2 以及对应的偏置 b1 和 b2 的状态和梯度。
  4. 初始化参数字典:

    • 如果上一个微批处理的参数字典存在,则使用该字典进行初始化。
    • 否则,使用模型的初始参数进行初始化。
  5. 转置输入数据:

    • 将输入数据进行转置,以适应后续的计算流程。
  6. 分配输出张量:

    • 分配一个空的输出张量 XQW_batch
  7. 扫描微批处理:

    • 使用 scan 函数对微批处理进行扫描计算。
    • 更新 batch_params_dict 和 XQW_batch
  8. 更新缓存参数:

    • 如果缓存参数存在,则更新缓存参数。
  9. 转置并重新形状化输出:

               对输出进行转置并重新形状化。

Block:

    • Block 类是模型的一个组成部分,通常包含多个子层(如TTT层、MLP层等)。
    • __init__ 方法初始化了各种层(如TTT层、MLP层、卷积层、归一化层等),根据配置决定是否使用卷积层。
    • forward 方法实现了前向传播过程,包含了TTT层、归一化和MLP层的计算,并应用了残差连接。

TTTPreTrainedModel:

  • TTTPreTrainedModel 类是预训练模型的基类。
  • 包含模型的配置类和基础模型前缀信息,支持梯度检查点功能。
  • _init_weights 方法初始化模型的权重,主要针对线性层和嵌入层。
class Block(nn.Module):
    def __init__(self, config: TTTConfig, layer_idx: int):
        """初始化Block层

        Args:
            config (TTTConfig): 模型配置对象。
            layer_idx (int): 当前层的索引。
        """
        super().__init__()
        self.hidden_size = config.hidden_size  # 隐藏层大小
        self.pre_conv = config.pre_conv  # 是否使用卷积层

        # 根据配置选择TTT层类型
        if config.ttt_layer_type == "linear":
            ttt_layer = TTTLinear
        elif config.ttt_layer_type == "mlp":
            ttt_layer = TTTMLP
        else:
            raise ValueError(f"Invalid ttt_layer_type: {config.ttt_layer_type}")

        # 序列建模块
        self.seq_modeling_block = ttt_layer(config=config, layer_idx=layer_idx)

        # MLP层
        self.mlp = SwiGluMLP(config)

        # 如果启用卷积,则初始化卷积层
        if self.pre_conv:
            self.conv = Conv(config, layer_idx)

        # 归一化层
        self.seq_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.ffn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.layer_idx = layer_idx

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        cache_params: Optional[TTTCache] = None,
    ) -> torch.Tensor:
        """前向传播函数

        Args:
            hidden_states (torch.Tensor): 输入隐藏状态张量。
            attention_mask (Optional[torch.Tensor]): 注意力掩码。
            position_ids (Optional[torch.LongTensor]): 位置ID。
            cache_params (Optional[TTTCache]): 缓存参数。

        Returns:
            torch.Tensor: 输出隐藏状态张量。
        """
        if self.pre_conv:
            residual = hidden_states  # 保存残差
            hidden_states = self.conv(hidden_states, cache_params=cache_params)  # 应用卷积层
            hidden_states = residual + hidden_states  # 残差连接

        residual = hidden_states  # 保存残差
        hidden_states = self.seq_norm(hidden_states)  # 应用归一化

        # TTT层
        hidden_states = self.seq_modeling_block(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            cache_params=cache_params,
        )
        hidden_states = residual + hidden_states  # 残差连接

        # 前馈网络
        residual = hidden_states  # 保存残差
        hidden_states = self.ffn_norm(hidden_states)  # 应用归一化
        hidden_states = self.mlp(hidden_states)  # 应用MLP层
        hidden_states = residual + hidden_states  # 残差连接

        return hidden_states


class TTTPreTrainedModel(PreTrainedModel):
    config_class = TTTConfig  # 配置类
    base_model_prefix = "model"  # 基础模型前缀
    supports_gradient_checkpointing = True  # 支持梯度检查点
    _no_split_modules = ["Block"]  # 不拆分的模块列表

    def _init_weights(self, module):
        """初始化权重

        Args:
            module (nn.Module): 需要初始化的模块。
        """
        std = self.config.initializer_range  # 初始化范围
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)  # 初始化线性层权重
            if module.bias is not None:
                module.bias.data.zero_()  # 初始化线性层偏置
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)  # 初始化嵌入层权重
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()  # 初始化嵌入层填充索引的权重

class TTTOutput和class TTTCausalLMOutput

@dataclass
class TTTOutput(ModelOutput):
    """
    TTT模型输出类。

    Args:
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            模型最后一层输出的隐藏状态序列。
        cache_params (`TTTCache`):
            模型在最后一个时间步的状态。可以在下一个 `input_ids` 的前向方法中使用,以避免提供旧的 `input_ids`。
        hidden_states (`Tuple[torch.FloatTensor]`, optional):
            模型各层的隐藏状态序列(可选)。
    """

    last_hidden_state: Optional[torch.FloatTensor] = None  # 最后一层隐藏状态
    cache_params: Optional[TTTCache] = None  # 缓存参数
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None  # 各层隐藏状态序列(可选)


@dataclass
class TTTCausalLMOutput(ModelOutput):
    """
    因果语言模型(或自回归)输出的基类。

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, optional):
            语言模型损失(用于下一个token预测),仅在提供 `labels` 时返回。
        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
            语言建模头的预测分数(SoftMax之前每个词汇token的分数)。
        cache_params (`TTTCache`):
            模型在最后一个时间步的状态。可以在下一个 `input_ids` 的前向方法中使用,以避免提供旧的 `input_ids`。
        hidden_states (`Tuple[torch.FloatTensor]`, optional):
            模型各层的隐藏状态序列(可选)。
    """

    loss: Optional[torch.FloatTensor] = None  # 损失
    logits: Optional[torch.FloatTensor] = None  # 预测分数
    cache_params: Optional[TTTCache] = None  # 缓存参数
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None  # 各层隐藏状态序列(可选)

class TTTModel

  1. 基于 TTTPreTrainedModel 的解码器模型类,包含多个 Block 层。

forward:

  1. 校验输入参数确保不能同时指定 input_ids 和 inputs_embeds
  2. 检查梯度检查点与缓存的兼容性。
  3. 根据 input_ids 获取输入嵌入,如果未提供 inputs_embeds
  4. 初始化 cache_params,如果启用了缓存。
  5. 计算位置ID的偏移量。
  6. 创建注意力掩码,如果未提供。
  7. 通过多个解码层进行计算,每层包括梯度检查点。
  8. 规范化最后的隐藏状态。
  9. 返回结果,支持返回字典格式或元组格式。
  10. class TTTModel(TTTPreTrainedModel):
        """
        包含 *config.num_hidden_layers* 层的解码器。每一层都是一个 [`Block`]
    
        Args:
            config: TTTConfig
        """
    
        def __init__(self, config: TTTConfig):
            super().__init__(config)
            self.padding_idx = config.pad_token_id  # 填充索引
            self.vocab_size = config.vocab_size  # 词汇表大小
    
            # 嵌入层
            self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
            # 多层结构
            self.layers = nn.ModuleList([Block(config, layer_idx) for layer_idx in range(config.num_hidden_layers)])
            # 归一化层
            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
            self.gradient_checkpointing = False  # 是否使用梯度检查点
    
            # 初始化权重并应用最终处理
            self.post_init()
    
        def get_input_embeddings(self):
            """获取输入嵌入层"""
            return self.embed_tokens
    
        def set_input_embeddings(self, value):
            """设置输入嵌入层"""
            self.embed_tokens = value
    
        def forward(
            self,
            input_ids: torch.LongTensor = None,
            attention_mask: Optional[torch.Tensor] = None,
            position_ids: Optional[torch.LongTensor] = None,
            inputs_embeds: Optional[torch.FloatTensor] = None,
            cache_params: Optional[TTTCache] = None,
            output_hidden_states: Optional[bool] = None,
            return_dict: Optional[bool] = None,
            use_cache: Optional[bool] = None,
        ) -> Union[Tuple, ModelOutput]:
            """前向传播函数
    
            Args:
                input_ids (torch.LongTensor, optional): 输入ID。
                attention_mask (Optional[torch.Tensor], optional): 注意力掩码。
                position_ids (Optional[torch.LongTensor], optional): 位置ID。
                inputs_embeds (Optional[torch.FloatTensor], optional): 输入嵌入。
                cache_params (Optional[TTTCache], optional): 缓存参数。
                output_hidden_states (Optional[bool], optional): 是否输出隐藏状态。
                return_dict (Optional[bool], optional): 是否返回字典。
                use_cache (Optional[bool], optional): 是否使用缓存。
    
            Returns:
                Union[Tuple, ModelOutput]: 输出结果。
            """
            output_hidden_states = (
                output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
            )
            use_cache = use_cache if use_cache is not None else self.config.use_cache
            return_dict = return_dict if return_dict is not None else self.config.use_return_dict
    
            # 确保不能同时指定 input_ids 和 inputs_embeds,且必须指定其中一个
            if (input_ids is None) ^ (inputs_embeds is not None):
                raise ValueError(
                    "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
                )
    
            # 检查梯度检查点与缓存的兼容性
            if self.gradient_checkpointing and self.training and use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
                )
                use_cache = False
    
            # 如果未提供 inputs_embeds,则根据 input_ids 获取输入嵌入
            if inputs_embeds is None:
                inputs_embeds = self.embed_tokens(input_ids)
    
            # 如果未提供 cache_params 且使用缓存,则初始化 cache_params
            if cache_params is None and use_cache:
                cache_params = TTTCache(self, inputs_embeds.size(0))
    
            # 计算位置ID的偏移量
            seqlen_offset = 0
            if cache_params is not None:
                seqlen_offset = cache_params.seqlen_offset
            position_ids = torch.arange(
                seqlen_offset,
                seqlen_offset + inputs_embeds.shape[1],
                dtype=torch.long,
                device=inputs_embeds.device,
            ).unsqueeze(0)
    
            hidden_states = inputs_embeds
    
            # 如果未提供 attention_mask,则创建一个全为1的掩码
            if attention_mask is None:
                attention_mask = torch.ones_like(input_ids)
    
            # 解码层
            all_hidden_states = () if output_hidden_states else None
    
            for decoder_layer in self.layers:
                if self.gradient_checkpointing and self.training:
                    hidden_states = self._gradient_checkpointing_func(
                        decoder_layer.__call__,
                        hidden_states,
                        attention_mask,
                        position_ids,
                        cache_params,
                    )
                else:
                    hidden_states = decoder_layer(
                        hidden_states,
                        attention_mask=attention_mask,
                        position_ids=position_ids,
                        cache_params=cache_params,
                    )
    
                if output_hidden_states:
                    all_hidden_states = all_hidden_states + (hidden_states,)
    
            if use_cache:
                cache_params.seqlen_offset += inputs_embeds.shape[1]
    
            hidden_states = self.norm(hidden_states)
    
            # 添加最后一个解码层的隐藏状态
            if output_hidden_states:
                all_hidden_states += (hidden_states,)
    
            if not return_dict:
                return tuple(v for v in [hidden_states, cache_params, all_hidden_states] if v is not None)
    
            return TTTOutput(
                last_hidden_state=hidden_states,
                cache_params=cache_params if use_cache else None,
                hidden_states=all_hidden_states,
            )

class TTTForCausalLM

class TTTForCausalLM(TTTPreTrainedModel):
    _tied_weights_keys = ["lm_head.weight"]  # 绑定权重键

    def __init__(self, config):
        super().__init__(config)
        self.model = TTTModel(config)  # 初始化解码器模型
        self.vocab_size = config.vocab_size  # 词汇表大小
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)  # 语言模型头

        # 初始化权重并应用最终处理
        self.post_init()

    def get_input_embeddings(self):
        """获取输入嵌入层"""
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        """设置输入嵌入层"""
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        """获取输出嵌入层"""
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        """设置输出嵌入层"""
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        """设置解码器"""
        self.model = decoder

    def get_decoder(self):
        """获取解码器"""
        return self.model

    def _update_model_kwargs_for_generation(
        self, outputs: ModelOutput, model_kwargs: Dict[str, Any], **kwargs
    ) -> Dict[str, Any]:
        """更新生成过程中的模型参数

        Args:
            outputs (ModelOutput): 模型输出。
            model_kwargs (Dict[str, Any]): 模型参数。

        Returns:
            Dict[str, Any]: 更新后的模型参数。
        """
        model_kwargs["cache_params"] = outputs.get("cache_params", None)
        # 更新注意力掩码
        if "attention_mask" in model_kwargs:
            attention_mask = model_kwargs["attention_mask"]
            model_kwargs["attention_mask"] = torch.cat(
                [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))],
                dim=-1,
            )
        return model_kwargs

    def prepare_inputs_for_generation(
        self,
        input_ids,
        attention_mask=None,
        cache_params: Optional[TTTCache] = None,
        inputs_embeds=None,
        **kwargs,
    ):
        """准备生成过程中的输入

        Args:
            input_ids (torch.LongTensor): 输入ID。
            attention_mask (Optional[torch.Tensor], optional): 注意力掩码。
            cache_params (Optional[TTTCache], optional): 缓存参数。
            inputs_embeds (Optional[torch.FloatTensor], optional): 输入嵌入。

        Returns:
            Dict[str, Any]: 准备好的输入。
        """
        # 仅保留最后一个token的input_ids,如果状态传递
        if cache_params is not None:
            input_ids = input_ids[:, -1].unsqueeze(-1)
            attention_mask = attention_mask[:, -1].unsqueeze(-1) if attention_mask is not None else None

        if inputs_embeds is not None and cache_params is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        model_inputs.update(
            {
                "cache_params": cache_params,
                "use_cache": kwargs.get("use_cache"),
                "attention_mask": attention_mask,
            }
        )

        return model_inputs

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        cache_params: Optional[TTTCache] = None,
        labels: Optional[torch.LongTensor] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        use_cache: Optional[bool] = None,
        *,
        output_attentions: Optional[bool] = None,
    ) -> Union[Tuple, TTTCausalLMOutput]:
        """前向传播函数

        Args:
            input_ids (torch.LongTensor, optional): 输入ID。
            attention_mask (Optional[torch.Tensor], optional): 注意力掩码。
            position_ids (Optional[torch.LongTensor], optional): 位置ID。
            inputs_embeds (Optional[torch.FloatTensor], optional): 输入嵌入。
            cache_params (Optional[TTTCache], optional): 缓存参数。
            labels (Optional[torch.LongTensor], optional): 标签。
            output_hidden_states (Optional[bool], optional): 是否输出隐藏状态。
            return_dict (Optional[bool], optional): 是否返回字典。
            use_cache (Optional[bool], optional): 是否使用缓存。
            output_attentions (Optional[bool], optional): 是否输出注意力。

        Returns:
            Union[Tuple, TTTCausalLMOutput]: 输出结果。
        """
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        assert not output_attentions, "output_attentions is not available in TTTForCausalLM"

        # 解码器输出包括 (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            cache_params=cache_params,
            inputs_embeds=inputs_embeds,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            use_cache=use_cache,
        )

        hidden_states = outputs[0]
        if self.config.pretraining_tp > 1:
            lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
            logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
            logits = torch.cat(logits, dim=-1)
        else:
            logits = self.lm_head(hidden_states)
        logits = logits.float()

        loss = None
        if labels is not None:
            # 移位以便 tokens < n 预测 n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # 展平 tokens
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # 启用模型并行
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return TTTCausalLMOutput(
            loss=loss,
            logits=logits,
            cache_params=outputs.cache_params,
            hidden_states=outputs.hidden_states,
        )

用于因果语言建模(Causal Language Modeling)的类,它继承自 TTTPreTrainedModel。此类的主要作用是将一个预训练的 TTT 模型(TTTModel)封装成一个用于语言生成的模型

  1. 初始化 (__init__): 初始化模型,包括解码器模型 TTTModel 和语言模型头 lm_head,后者是一个线性层,用于将隐藏状态映射到词汇表大小的 logits 上。

  2. 输入嵌入 (get_input_embeddingsset_input_embeddings): 提供获取和设置模型输入嵌入层的方法。

  3. 输出嵌入 (get_output_embeddingsset_output_embeddings): 提供获取和设置模型输出嵌入层的方法,这通常用于调整模型的输出以适应不同的任务。

  4. 设置和获取解码器 (set_decoderget_decoder): 允许用户设置自定义的解码器,并能够获取当前使用的解码器。

  5. 更新生成参数 (_update_model_kwargs_for_generation): 在生成文本时,更新模型参数,如缓存参数和注意力掩码。

  6. 准备生成输入 (prepare_inputs_for_generation): 准备用于文本生成的输入,包括处理输入 ID、注意力掩码和缓存参数。

  7. 前向传播 (forward): 定义了模型的前向传播逻辑,包括处理输入 ID、注意力掩码、位置 ID、输入嵌入、缓存参数等,并计算损失和 logits。如果提供标签 labels,则计算交叉熵损失。最后,根据 return_dict 参数决定返回格式。

  8. 因果语言模型输出 (TTTCausalLMOutput): 定义了因果语言模型的输出格式,包括损失、logits、缓存参数和隐藏状态。

TTTForCausalLM 类通过这些方法和属性,实现了一个完整的因果语言模型,能够用于文本生成任务,如文本续写、问答等。它利用了 TTT 模型的高效特性,并添加了语言模型头来进行词汇级别的预测

评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值