Class TTTBase
TTTBase 类通过方法和属性构建了一个灵活的模型层,可以根据不同配置调整其行为。这个类是构建更复杂模型的基础,如 TTTLinear 和 TTTMLP,它们继承自 TTTBase 并实现具体的 TTT 逻辑
class TTTBase(nn.Module):
def __init__(self, config: TTTConfig, layer_idx: Optional[int] = None):
"""初始化TTTBase层。
Args:
config (TTTConfig): 模型配置对象。
layer_idx (Optional[int]): 当前层的索引。默认值为None。
"""
super().__init__()
self.config = config
self.layer_idx = layer_idx
if layer_idx is None:
logger.warning_once(
f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
"lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
"when creating this class."
)
self.width = config.hidden_size # 设置隐藏层大小
self.hidden_size = config.hidden_size # 设置隐藏层大小
self.num_heads = config.num_attention_heads # 设置注意力头的数量
self.head_dim = self.width // self.num_heads # 计算每个头的维度
self.mini_batch_size = config.mini_batch_size # 设置微批处理大小
# token_idx 是一个缩放因子,用于在公式4中进行求和
token_idx = 1.0 / torch.arange(1, self.mini_batch_size + 1)
self.register_buffer("token_idx", token_idx, persistent=False)
# 使缩放因子可学习
self.learnable_token_idx = nn.Parameter(torch.zeros((self.mini_batch_size,)))
self.share_qk = config.share_qk # 是否共享Q/K投影
self.conv_kernel = config.conv_kernel # 卷积核大小
self._init_qkvo_proj() # 初始化Q/K/V输出投影
self._init_rope() # 初始化旋转位置嵌入
# 可学习的eta参数
self._init_ttt_lr_gate()
self._init_ttt_ln()
# 是否使用门控机制
self.use_gate = config.use_gate
if self.use_gate:
self.g_proj = nn.Linear(self.width, self.width, bias=False)
self.post_norm = nn.LayerNorm(self.width, eps=1e-6)
def _init_qkvo_proj(self):
"""初始化Q/K/V投影"""
self.q_proj = nn.Linear(self.width, self.num_heads * self.head_dim, bias=False)
if not self.share_qk:
self.k_proj = nn.Linear(self.width, self.num_heads * self.head_dim, bias=False)
self.v_proj = nn.Linear(self.width, self.num_heads * self.head_dim, bias=False)
self.o_proj = nn.Linear(self.width, self.num_heads * self.head_dim, bias=False)
if self.share_qk:
self.conv_q = nn.Conv1d(
self.hidden_size,
self.hidden_size,
bias=True,
kernel_size=self.conv_kernel,
groups=self.hidden_size,
padding=self.conv_kernel - 1,
)
self.conv_k = nn.Conv1d(
self.hidden_size,
self.hidden_size,
bias=True,
kernel_size=self.conv_kernel,
groups=self.hidden_size,
padding=self.conv_kernel - 1,
)
def _init_rope(self):
"""初始化旋转位置嵌入"""
self.rope_theta = self.config.rope_theta
self.rotary_emb = RotaryEmbedding(
self.head_dim,
max_position_embeddings=self.mini_batch_size,
base=self.rope_theta,
)
def _init_ttt_lr_gate(self):
"""初始化可学习的η参数"""
linear_weight_data = nn.Linear(self.width, 1, bias=True).weight.data
self.learnable_ttt_lr_weight = nn.Parameter(
torch.stack(
[torch.normal(0, 0.02, size=linear_weight_data.shape) for _ in range(self.num_heads)],
dim=0,
)
)
linear_bias_data = nn.Linear(self.width, 1, bias=True).bias.data
self.learnable_ttt_lr_bias = nn.Parameter(
torch.stack(
[torch.zeros_like(linear_bias_data) for _ in range(self.num_heads)],
dim=0,
)
)
def _init_ttt_ln(self):
"""初始化可学习的层归一化参数"""
ln_weight_data = nn.LayerNorm(self.head_dim).weight.data
self.ttt_norm_weight = nn.Parameter(torch.tile(ln_weight_data.unsqueeze(0), (self.num_heads, 1)))
ln_bias_data = nn.LayerNorm(self.head_dim).bias.data
self.ttt_norm_bias = nn.Parameter(torch.tile(ln_bias_data.unsqueeze(0), (self.num_heads, 1)))
def get_qkv_projections(self, hidden_states, cache_params: Optional[TTTCache] = None):
"""获取Q/K/V投影
Args:
hidden_states (torch.Tensor): 输入隐藏状态。
cache_params (Optional[TTTCache]): 缓存参数。默认值为None。
Returns:
Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Q/K/V投影张量。
"""
if self.share_qk:
xq, XV = self.q_proj(hidden_states), self.v_proj(hidden_states)
seq_len = xq.shape[1]
xq = xq.transpose(1, 2)
if causal_conv1d_fn is None:
if cache_params is not None:
if cache_params.seqlen_offset > 0:
conv_q_state = cache_params.conv_states_dic["ttt_conv_q"][self.layer_idx]
conv_q_state = torch.roll(conv_q_state, shifts=-1, dims=-1)
conv_q_state[:, :, -1] = xq[:, :, 0]
cache_params.conv_states_dic["ttt_conv_q"][self.layer_idx].copy_(conv_q_state)
XQ = torch.sum(conv_q_state * self.conv_q.weight[:, 0, :], dim=-1)
XQ += self.conv_q.bias
XQ = XQ.unsqueeze(-1)
conv_k_state = cache_params.conv_states_dic["ttt_conv_k"][self.layer_idx]
conv_k_state = torch.roll(conv_k_state, shifts=-1, dims=-1)
conv_k_state[:, :, -1] = xq[:, :, 0]
cache_params.conv_states_dic["ttt_conv_k"][self.layer_idx].copy_(conv_k_state)
XK = torch.sum(conv_k_state * self.conv_k.weight[:, 0, :], dim=-1)
XK += self.conv_k.bias
XK = XK.unsqueeze(-1)
else:
conv_q_state = nn.functional.pad(xq, (self.config.conv_kernel - xq.shape[-1], 0))
cache_params.conv_states_dic["ttt_conv_q"][self.layer_idx].copy_(conv_q_state)
XQ = self.conv_q(xq)[..., :seq_len]
conv_k_state = nn.functional.pad(xq, (self.config.conv_kernel - xq.shape[-1], 0))
cache_params.conv_states_dic["ttt_conv_k"][self.layer_idx].copy_(conv_k_state)
XK = self.conv_k(xq)[..., :seq_len]
else:
XQ = self.conv_q(xq)[..., :seq_len]
XK = self.conv_k(xq)[..., :seq_len]
else:
conv_q_weights = self.conv_q.weight.view(self.conv_q.weight.size(0), self.conv_q.weight.size(2))
conv_k_weights = self.conv_k.weight.view(self.conv_k.weight.size(0), self.conv_k.weight.size(2))
if cache_params is not None and cache_params.seqlen_offset > 0:
XQ = causal_conv1d_update(
xq.squeeze(-1),
cache_params.conv_states_dic["ttt_conv_q"][self.layer_idx],
conv_q_weights,
self.conv_q.bias,
None,
)
XQ = XQ.unsqueeze(-1)
XK = causal_conv1d_update(
xq.squeeze(-1),
cache_params.conv_states_dic["ttt_conv_k"][self.layer_idx],
conv_k_weights,
self.conv_k.bias,
None,
)
XK = XK.unsqueeze(-1)
else:
if cache_params is not None:
conv_q_states = nn.functional.pad(xq, (self.config.conv_kernel - xq.shape[-1], 0))
cache_params.conv_states_dic["ttt_conv_q"][self.layer_idx].copy_(conv_q_states)
conv_k_states = nn.functional.pad(xq, (self.config.conv_kernel - xq.shape[-1], 0))
cache_params.conv_states_dic["ttt_conv_k"][self.layer_idx].copy_(conv_k_states)
XQ = causal_conv1d_fn(xq, conv_q_weights, self.conv_q.bias, activation=None)
XK = causal_conv1d_fn(xq, conv_k_weights, self.conv_k.bias, activation=None)
XQ = XQ.transpose(1, 2)
XK = XK.transpose(1, 2)
else:
XQ, XK, XV = (
self.q_proj(hidden_states),
self.k_proj(hidden_states),
self.v_proj(hidden_states),
)
return XQ, XK, XV
def _split_heads(self, hidden_states):
"""将隐藏状态分成多个头"""
return hidden_states.reshape(hidden_states.shape[:2] + (self.num_heads, self.head_dim))
def get_eta(self, X, mini_batch_step_offset, mini_batch_size):
"""计算可学习的η参数
Args:
X (torch.Tensor): 输入张量。
mini_batch_step_offset (int): 微批处理步长偏移量。
mini_batch_size (int): 微批处理大小。
Returns:
Tuple[torch.Tensor, torch.Tensor]: token_eta和ttt_lr_eta张量。
"""
ttt_lr = torch.einsum("bnkc,hdc->bhnkd", X, self.learnable_ttt_lr_weight) + self.learnable_ttt_lr_bias.reshape(
1, -1, 1, 1, 1
)
ttt_lr = F.sigmoid(ttt_lr)
ttt_lr = ttt_lr.permute(0, 1, 2, 4, 3)
ttt_lr_eta = self.config.ttt_base_lr * ttt_lr / self.head_dim
token_idx = self.token_idx + self.learnable_token_idx
token_idx = token_idx[mini_batch_step_offset : mini_batch_step_offset + mini_batch_size]
token_idx = torch.clamp_min(token_idx, 0.0)
token_eta = torch.broadcast_to(
token_idx.reshape(1, 1, 1, mini_batch_size, 1),
(X.shape[0], self.num_heads, X.shape[1], mini_batch_size, 1),
)
return token_eta, ttt_lr_eta
def apply_gate(self, hidden_states, ttt_output):
"""应用门控机制
Args:
hidden_states (torch.Tensor): 输入隐藏状态。
ttt_output (torch.Tensor): TTT输出。
Returns:
torch.Tensor: 门控后的输出。
"""
y = self.g_proj(hidden_states)
y = F.gelu(y, approximate="tanh")
output = y * ttt_output
return output
def get_ttt_inputs(self, inputs, mini_batch_size, cache_params):
"""获取TTT输入
Args:
inputs (dict): 输入数据。
mini_batch_size (int): 微批处理大小。
cache_params (Optional[TTTCache]): 缓存参数。默认值为None。
Returns:
dict: TTT输入数据。
"""
XQ = inputs["XQ"]
XK = inputs["XK"]
XV = inputs["XV"]
X = inputs["X"]
B, L, C = X.shape
num_mini_batch = L // mini_batch_size
X = X.reshape(B, num_mini_batch, mini_batch_size, self.width)
XQ = XQ.reshape(B, self.num_heads, L // mini_batch_size, mini_batch_size, self.head_dim)
XK = XK.reshape(B, self.num_heads, L // mini_batch_size, mini_batch_size, self.head_dim)
XV = XV.reshape(B, self.num_heads, L // mini_batch_size, mini_batch_size, self.head_dim)
if cache_params is not None:
mini_batch_step_offset = cache_params.seqlen_offset % self.mini_batch_size
else:
mini_batch_step_offset = 0
token_eta, ttt_lr_eta = self.get_eta(X, mini_batch_step_offset, mini_batch_size)
eta = token_eta * ttt_lr_eta
inputs = {
"XQ": XQ,
"XK": XK,
"XV": XV,
"eta": eta,
"token_eta": token_eta,
"ttt_lr_eta": ttt_lr_eta,
}
return inputs
def ttt(
self,
inputs,
mini_batch_size,
last_mini_batch_params_dict,
cache_params: Optional[TTTCache] = None,
):
"""TTT方法,必须在TTTBase的子类中实现"""
raise NotImplementedError("ttt方法必须在TTTBase子类中实现。")
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
cache_params: Optional[TTTCache] = None,
):
"""前向传播函数
Args:
hidden_states (torch.Tensor): 输入隐藏状态张量。
attention_mask (Optional[torch.Tensor]): 注意力掩码。默认值为None。
position_ids (Optional[torch.LongTensor]): 位置ID。默认值为None。
cache_params (Optional[TTTCache]): 缓存参数。默认值为None。
Returns:
torch.Tensor: 输出隐藏状态张量。
"""
# 获取批处理大小和序列长度
B, L = hidden_states.shape[:2]
# 计算序列长度除以微批处理大小的余数
reminder_len = L % self.mini_batch_size
# 计算序列可以被微批处理大小整除的部分
num_mini_batch = L // self.mini_batch_size
last_mini_batch_params_dict = None
# 获取Q/K/V投影
XQ, XK, XV = self.get_qkv_projections(hidden_states, cache_params=cache_params)
# 将Q/K/V重新形状化并进行转置
XQ = XQ.reshape(B, L, self.num_heads, self.head_dim).transpose(1, 2)
XK = XK.reshape(B, L, self.num_heads, self.head_dim).transpose(1, 2)
XV = XV.reshape(B, L, self.num_heads, self.head_dim).transpose(1, 2)
# 获取旋转嵌入的余弦和正弦
cos, sin = self.rotary_emb(XV, position_ids % self.mini_batch_size)
# 应用旋转位置嵌入
XQ, XK = permute_qk(XQ, XK)
XQ, XK = apply_rotary_pos_emb(XQ, XK, cos, sin)
XQ, XK = undo_permute_qk(XQ, XK)
# 初始化输出隐藏状态列表
output_hidden_states = []
# 如果序列长度可以被微批处理大小整除
if num_mini_batch > 0:
# 构建输入字典
inputs = {
"XQ": XQ[:, :, : num_mini_batch * self.mini_batch_size],
"XK": XK[:, :, : num_mini_batch * self.mini_batch_size],
"XV": XV[:, :, : num_mini_batch * self.mini_batch_size],
"X": hidden_states[:, : num_mini_batch * self.mini_batch_size],
}
# 获取TTT输入并调用TTT方法
output_mod, last_mini_batch_params_dict = self.ttt(
self.get_ttt_inputs(inputs, self.mini_batch_size, cache_params),
mini_batch_size=self.mini_batch_size,
last_mini_batch_params_dict=last_mini_batch_params_dict,
cache_params=cache_params,
)
# 将输出添加到输出隐藏状态列表中
output_hidden_states.append(output_mod)
# 如果序列长度有余数部分
if reminder_len > 0:
# 构建输入字典
inputs = {
"XQ": XQ[:, :, -reminder_len:],
"XK": XK[:, :, -reminder_len:],
"XV": XV[:, :, -reminder_len:],
"X": hidden_states[:, -reminder_len:],
}
# 获取TTT输入并调用TTT方法
output_reminder, _ = self.ttt(
self.get_ttt_inputs(inputs, reminder_len, cache_params),
mini_batch_size=reminder_len,
last_mini_batch_params_dict=last_mini_batch_params_dict,
cache_params=cache_params,
)
# 将输出添加到输出隐藏状态列表中
output_hidden_states.append(output_reminder)
# 将所有输出隐藏状态拼接在一起
output_hidden_states = torch.cat(output_hidden_states, dim=1)
# 应用层归一化
output_hidden_states = self.post_norm(output_hidden_states)
# 如果使用门控机制
if self.use_gate:
output_hidden_states = self.apply_gate(hidden_states, output_hidden_states)
# 应用输出投影
output_hidden_states = self.o_proj(output_hidden_states)
return output_hidden_states
TTTLinear 和 TTTMLP
TTTLinear 和 TTTMLP,它们继承自 TTTBase 并实现具体的 TTT 逻辑
class TTTLinear(TTTBase):
def __init__(self, config: TTTConfig, layer_idx: Optional[int] = None):
"""初始化TTT-Linear层
Args:
config (TTTConfig): 模型配置对象。
layer_idx (Optional[int]): 当前层的索引。默认值为None。
"""
super().__init__(config, layer_idx)
# 初始化TTT-Linear模型参数
self.W1 = nn.Parameter(torch.normal(0, 0.02, size=(self.num_heads, self.head_dim, self.head_dim)))
self.b1 = nn.Parameter(torch.zeros(self.num_heads, 1, self.head_dim))
def ttt(
self,
inputs,
mini_batch_size,
last_mini_batch_params_dict,
cache_params: Optional[TTTCache] = None,
):
"""实现TTT方法
Args:
inputs (dict): 输入数据。
mini_batch_size (int): 微批处理大小。
last_mini_batch_params_dict (dict): 上一个微批处理的参数字典。
cache_params (Optional[TTTCache]): 缓存参数。默认值为None。
Returns:
Tuple[torch.Tensor, dict]: 输出张量和更新后的参数字典。
"""
if mini_batch_size is None:
mini_batch_size = self.mini_batch_size
# 如果缓存参数存在但上一个微批处理的参数字典不存在,则获取缓存参数的字典
if last_mini_batch_params_dict is None and cache_params is not None:
last_mini_batch_params_dict = cache_params.ttt_params_to_dict(self.layer_idx)
# 获取输入张量的形状参数
B = inputs["XV"].shape[0] # 批处理大小
num_mini_batch = inputs["XV"].shape[2] # 微批处理数量
L = inputs["XV"].shape[2] * inputs["XV"].shape[3] # 总序列长度
device = inputs["XV"].device # 设备
dtype = inputs["XV"].dtype # 数据类型
# 判断是否使用对偶形式进行计算
use_dual_form = cache_params is None or mini_batch_size % self.mini_batch_size == 0
def compute_mini_batch(params_dict, inputs):
"""计算微批处理
Args:
params_dict (dict): 参数字典。
inputs (dict): 输入数据。
Returns:
Tuple[dict, torch.Tensor]: 更新后的参数字典和输出张量。
"""
# 获取初始参数
W1_init = params_dict["W1_states"]
b1_init = params_dict["b1_states"]
# 获取微批处理输入
XQ_mini_batch = inputs["XQ"]
XV_mini_batch = inputs["XV"]
XK_mini_batch = inputs["XK"]
eta_mini_batch = inputs["eta"]
token_eta_mini_batch = inputs["token_eta"]
ttt_lr_eta_mini_batch = inputs["ttt_lr_eta"]
# 计算线性变换
X1 = XK_mini_batch
Z1 = X1 @ W1_init + b1_init
reconstruction_target = XV_mini_batch - XK_mini_batch
# 获取层归一化参数
ln_weight = self.ttt_norm_weight.reshape(self.num_heads, 1, self.head_dim)
ln_bias = self.ttt_norm_bias.reshape(self.num_heads, 1, self.head_dim)
# 计算梯度
grad_l_wrt_Z1 = ln_fused_l2_bwd(Z1, reconstruction_target, ln_weight, ln_bias)
if use_dual_form:
# 对偶形式计算
Attn1 = torch.tril(XQ_mini_batch @ X1.transpose(-2, -1))
b1_bar = b1_init - torch.tril(eta_mini_batch) @ grad_l_wrt_Z1
Z1_bar = XQ_mini_batch @ W1_init - (eta_mini_batch * Attn1) @ grad_l_wrt_Z1 + b1_bar
last_eta_mini_batch = eta_mini_batch[:, :, -1, :, None]
W1_last = W1_init - (last_eta_mini_batch * X1).transpose(-1, -2) @ grad_l_wrt_Z1
b1_last = b1_init - torch.sum(last_eta_mini_batch * grad_l_wrt_Z1, dim=-2, keepdim=True)
grad_W1_last = torch.zeros_like(W1_last)
grad_b1_last = torch.zeros_like(b1_last)
else:
# 原始形式计算
ttt_lr_eta_mini_batch = torch.broadcast_to(
ttt_lr_eta_mini_batch,
(
*ttt_lr_eta_mini_batch.shape[:2],
mini_batch_size,
mini_batch_size,
),
)
grad_W1 = torch.einsum("bhki,bhkj->bhkij", X1, grad_l_wrt_Z1)
grad_W1 = torch.einsum("bhnk,bhkij->bhnij", torch.tril(ttt_lr_eta_mini_batch), grad_W1)
grad_W1 = grad_W1 + params_dict["W1_grad"].unsqueeze(2)
grad_b1 = torch.einsum("bhnk,bhki->bhni", torch.tril(ttt_lr_eta_mini_batch), grad_l_wrt_Z1)
grad_b1 = grad_b1 + params_dict["b1_grad"]
W1_bar = W1_init.unsqueeze(2) - grad_W1 * token_eta_mini_batch.unsqueeze(-1)
b1_bar = b1_init - grad_b1 * token_eta_mini_batch
Z1_bar = (XQ_mini_batch.unsqueeze(3) @ W1_bar).squeeze(3) + b1_bar
W1_last = W1_bar[:, :, -1]
b1_last = b1_bar[:, :, -1:]
grad_W1_last = grad_W1[:, :, -1]
grad_b1_last = grad_b1[:, :, -1:]
# 应用层归一化
Z1_bar = ln_fwd(Z1_bar, ln_weight, ln_bias)
# 更新Q投影
XQW_mini_batch = XQ_mini_batch + Z1_bar
last_param_dict = {
"W1_states": W1_last,
"b1_states": b1_last,
"W1_grad": grad_W1_last,
"b1_grad": grad_b1_last,
}
return last_param_dict, XQW_mini_batch
# 初始化参数字典
if last_mini_batch_params_dict is not None:
init_params_dict = last_mini_batch_params_dict
else:
init_params_dict = {
"W1_states": torch.tile(self.W1.unsqueeze(0), dims=(B, 1, 1, 1)),
"b1_states": torch.tile(self.b1.unsqueeze(0), dims=(B, 1, 1, 1)),
}
init_params_dict.update(W1_grad=torch.zeros_like(init_params_dict["W1_states"]))
init_params_dict.update(b1_grad=torch.zeros_like(init_params_dict["b1_states"]))
# 转置输入数据
inputs = tree_map(lambda x: x.permute(2, 0, 1, 3, 4), inputs)
# 分配输出张量
XQW_batch = torch.empty(
(num_mini_batch, B, self.num_heads, mini_batch_size, self.head_dim),
device=device,
dtype=dtype,
)
# 扫描微批处理
batch_params_dict, XQW_batch = scan(
compute_mini_batch,
init_params_dict,
inputs,
XQW_batch,
self.config.scan_checkpoint_group_size if self.training else 0,
)
# 更新缓存参数
if cache_params is not None:
cache_params.update(batch_params_dict, self.layer_idx, L)
# 转置并重新形状化输出
XQW_batch = XQW_batch.permute(1, 0, 3, 2, 4)
XQW_batch = XQW_batch.reshape(B, L, self.width)
return XQW_batch, batch_params_dict
-
__init__
方法:- 初始化了TTT-Linear层,继承自
TTTBase
。 - 初始化了线性变换的权重
W1
和偏置b1
。
- 初始化了TTT-Linear层,继承自
-
ttt
方法:- 实现了TTT方法,用于处理微批处理的输入数据。
- 判断是否使用对偶形式进行计算。
-
compute_mini_batch
函数:- 计算微批处理的线性变换和梯度。
- 根据是否使用对偶形式,分别进行计算。
- 更新
W1
和b1
的状态和梯度。
-
初始化参数字典:
- 如果上一个微批处理的参数字典存在,则使用该字典进行初始化。
- 否则,使用模型的初始参数进行初始化。
-
转置输入数据:
- 将输入数据进行转置,以适应后续的计算流程。
-
分配输出张量:
- 分配一个空的输出张量
XQW_batch
。
- 分配一个空的输出张量
-
扫描微批处理:
- 使用
scan
函数对微批处理进行扫描计算。 - 更新
batch_params_dict
和XQW_batch
。
- 使用
-
更新缓存参数:
- 如果缓存参数存在,则更新缓存参数。
-
转置并重新形状化输出:
- 对输出进行转置并重新形状化。
class TTTMLP(TTTBase):
def __init__(self, config: TTTConfig, layer_idx: Optional[int] = None):
"""初始化TTT-MLP层
Args:
config (TTTConfig): 模型配置对象。
layer_idx (Optional[int]): 当前层的索引。默认值为None。
"""
super().__init__(config, layer_idx)
# 初始化TTT-MLP模型参数
self.W1 = nn.Parameter(torch.normal(0, 0.02, size=(self.num_heads, self.head_dim, 4 * self.head_dim)))
self.b1 = nn.Parameter(torch.zeros(self.num_heads, 1, 4 * self.head_dim))
self.W2 = nn.Parameter(torch.normal(0, 0.02, size=(self.num_heads, 4 * self.head_dim, self.head_dim)))
self.b2 = nn.Parameter(torch.zeros(self.num_heads, 1, self.head_dim))
def ttt(
self,
inputs,
mini_batch_size,
last_mini_batch_params_dict,
cache_params: Optional[TTTCache] = None,
):
"""实现TTT方法
Args:
inputs (dict): 输入数据。
mini_batch_size (int): 微批处理大小。
last_mini_batch_params_dict (dict): 上一个微批处理的参数字典。
cache_params (Optional[TTTCache]): 缓存参数。默认值为None。
Returns:
Tuple[torch.Tensor, dict]: 输出张量和更新后的参数字典。
"""
if mini_batch_size is None:
mini_batch_size = self.mini_batch_size
# 如果缓存参数存在但上一个微批处理的参数字典不存在,则获取缓存参数的字典
if last_mini_batch_params_dict is None and cache_params is not None:
last_mini_batch_params_dict = cache_params.ttt_params_to_dict(self.layer_idx)
# 获取输入张量的形状参数
B = inputs["XV"].shape[0] # 批处理大小
num_mini_batch = inputs["XV"].shape[2] # 微批处理数量
L = inputs["XV"].shape[2] * inputs["XV"].shape[3] # 总序列长度
device = inputs["XV"].device # 设备
dtype = inputs["XV"].dtype # 数据类型
# 判断是否使用对偶形式进行计算
use_dual_form = cache_params is None or mini_batch_size % self.mini_batch_size == 0
def compute_mini_batch(params_dict, inputs):
"""计算微批处理
Args:
params_dict (dict): 参数字典。
inputs (dict): 输入数据。
Returns:
Tuple[dict, torch.Tensor]: 更新后的参数字典和输出张量。
"""
# 获取初始参数
W1_init = params_dict["W1_states"]
b1_init = params_dict["b1_states"]
W2_init = params_dict["W2_states"]
b2_init = params_dict["b2_states"]
# 获取微批处理输入
XQ_mini_batch = inputs["XQ"]
XV_mini_batch = inputs["XV"]
XK_mini_batch = inputs["XK"]
eta_mini_batch = inputs["eta"]
token_eta_mini_batch = inputs["token_eta"]
ttt_lr_eta_mini_batch = inputs["ttt_lr_eta"]
# 计算线性变换和激活函数
X1 = XK_mini_batch
Z1 = X1 @ W1_init + b1_init
X2 = F.gelu(Z1, approximate="tanh")
Z2 = X2 @ W2_init + b2_init
reconstruction_target = XV_mini_batch - XK_mini_batch
# 获取层归一化参数
ln_weight = self.ttt_norm_weight.reshape(self.num_heads, 1, self.head_dim)
ln_bias = self.ttt_norm_bias.reshape(self.num_heads, 1, self.head_dim)
# 计算梯度
grad_l_wrt_Z2 = ln_fused_l2_bwd(Z2, reconstruction_target, ln_weight, ln_bias)
grad_l_wrt_Z1 = grad_l_wrt_Z2 @ W2_init.transpose(-2, -1) * gelu_bwd(Z1)
if use_dual_form:
# 对偶形式计算
Attn1 = torch.tril(XQ_mini_batch @ X1.transpose(-2, -1))
b1_bar = b1_init - torch.tril(eta_mini_batch) @ grad_l_wrt_Z1
Z1_bar = XQ_mini_batch @ W1_init - (eta_mini_batch * Attn1) @ grad_l_wrt_Z1 + b1_bar
X2_bar = F.gelu(Z1_bar, approximate="tanh")
Attn2 = torch.tril(X2_bar @ X2.transpose(-2, -1))
b2_bar = b2_init - torch.tril(eta_mini_batch) @ grad_l_wrt_Z2
Z2_bar = X2_bar @ W2_init - (eta_mini_batch * Attn2) @ grad_l_wrt_Z2 + b2_bar
last_eta_mini_batch = eta_mini_batch[:, :, -1, :, None]
W1_last = W1_init - (last_eta_mini_batch * X1).transpose(-1, -2) @ grad_l_wrt_Z1
b1_last = b1_init - torch.sum(last_eta_mini_batch * grad_l_wrt_Z1, dim=-2, keepdim=True)
W2_last = W2_init - (last_eta_mini_batch * X2).transpose(-1, -2) @ grad_l_wrt_Z2
b2_last = b2_init - torch.sum(last_eta_mini_batch * grad_l_wrt_Z2, dim=-2, keepdim=True)
grad_W1_last = torch.zeros_like(W1_last)
grad_b1_last = torch.zeros_like(b1_last)
grad_W2_last = torch.zeros_like(W2_last)
grad_b2_last = torch.zeros_like(b2_last)
else:
# 原始形式计算
ttt_lr_eta_mini_batch = torch.broadcast_to(
ttt_lr_eta_mini_batch,
(
*ttt_lr_eta_mini_batch.shape[:2],
mini_batch_size,
mini_batch_size,
),
)
grad_W2 = torch.einsum("bhki,bhkj->bhkij", X2, grad_l_wrt_Z2)
grad_W2 = torch.einsum("bhnk,bhkij->bhnij", torch.tril(ttt_lr_eta_mini_batch), grad_W2)
grad_W2 = grad_W2 + params_dict["W2_grad"].unsqueeze(2)
grad_b2 = torch.einsum("bhnk,bhki->bhni", torch.tril(ttt_lr_eta_mini_batch), grad_l_wrt_Z2)
grad_b2 = grad_b2 + params_dict["b2_grad"]
grad_W1 = torch.einsum("bhki,bhkj->bhkij", X1, grad_l_wrt_Z1)
grad_W1 = torch.einsum("bhnk,bhkij->bhnij", torch.tril(ttt_lr_eta_mini_batch), grad_W1)
grad_W1 = grad_W1 + params_dict["W1_grad"].unsqueeze(2)
grad_b1 = torch.einsum("bhnk,bhki->bhni", torch.tril(ttt_lr_eta_mini_batch), grad_l_wrt_Z1)
grad_b1 = grad_b1 + params_dict["b1_grad"]
W1_bar = W1_init.unsqueeze(2) - grad_W1 * token_eta_mini_batch.unsqueeze(-1)
b1_bar = b1_init - grad_b1 * token_eta_mini_batch
W2_bar = W2_init.unsqueeze(2) - grad_W2 * token_eta_mini_batch.unsqueeze(-1)
b2_bar = b2_init - grad_b2 * token_eta_mini_batch
Z1_bar = (XQ_mini_batch.unsqueeze(3) @ W1_bar).squeeze(3) + b1_bar
X2_bar = F.gelu(Z1_bar, approximate="tanh")
Z2_bar = (X2_bar.unsqueeze(3) @ W2_bar).squeeze(3) + b2_bar
W1_last = W1_bar[:, :, -1]
b1_last = b1_bar[:, :, -1:]
W2_last = W2_bar[:, :, -1]
b2_last = b2_bar[:, :, -1:]
grad_W1_last = grad_W1[:, :, -1]
grad_b1_last = grad_b1[:, :, -1:]
grad_W2_last = grad_W2[:, :, -1]
grad_b2_last = grad_b2[:, :, -1:]
# 应用层归一化
Z2_bar = ln_fwd(Z2_bar, ln_weight, ln_bias)
# 更新Q投影
XQW_mini_batch = XQ_mini_batch + Z2_bar
last_param_dict = {
"W1_states": W1_last,
"b1_states": b1_last,
"W2_states": W2_last,
"b2_states": b2_last,
"W1_grad": grad_W1_last,
"b1_grad": grad_b1_last,
"W2_grad": grad_W2_last,
"b2_grad": grad_b2_last,
}
return last_param_dict, XQW_mini_batch
# 初始化参数字典
if last_mini_batch_params_dict is not None:
init_params_dict = last_mini_batch_params_dict
else:
init_params_dict = {
"W1_states": torch.tile(self.W1.unsqueeze(0), dims=(B, 1, 1, 1)),
"b1_states": torch.tile(self.b1.unsqueeze(0), dims=(B, 1, 1, 1)),
"W2_states": torch.tile(self.W2.unsqueeze(0), dims=(B, 1, 1, 1)),
"b2_states": torch.tile(self.b2.unsqueeze(0), dims=(B, 1, 1, 1)),
}
init_params_dict.update(W1_grad=torch.zeros_like(init_params_dict["W1_states"]))
init_params_dict.update(b1_grad=torch.zeros_like(init_params_dict["b1_states"]))
init_params_dict.update(W2_grad=torch.zeros_like(init_params_dict["W2_states"]))
init_params_dict.update(b2_grad=torch.zeros_like(init_params_dict["b2_states"]))
# 转置输入数据
inputs = tree_map(lambda x: x.permute(2, 0, 1, 3, 4), inputs)
# 分配输出张量
XQW_batch = torch.empty(
(num_mini_batch, B, self.num_heads, mini_batch_size, self.head_dim),
device=device,
dtype=dtype,
)
# 扫描微批处理
batch_params_dict, XQW_batch = scan(
compute_mini_batch,
init_params_dict,
inputs,
XQW_batch,
self.config.scan_checkpoint_group_size if self.training else 0,
)
# 更新缓存参数
if cache_params is not None:
cache_params.update(batch_params_dict, self.layer_idx, L)
# 转置并重新形状化输出
XQW_batch = XQW_batch.permute(1, 0, 3, 2, 4)
XQW_batch = XQW_batch.reshape(B, L, self.width)
return XQW_batch, batch_params_dict
-
__init__
方法:- 初始化了TTT-MLP层,继承自
TTTBase
。 - 初始化了两层线性变换的权重
W1
和W2
以及对应的偏置b1
和b2
。
- 初始化了TTT-MLP层,继承自
-
ttt
方法:- 实现了TTT方法,用于处理微批处理的输入数据。
- 判断是否使用对偶形式进行计算。
-
compute_mini_batch
函数:- 计算微批处理的线性变换、激活函数和梯度。
- 根据是否使用对偶形式,分别进行计算。
- 更新
W1
、W2
以及对应的偏置b1
和b2
的状态和梯度。
-
初始化参数字典:
- 如果上一个微批处理的参数字典存在,则使用该字典进行初始化。
- 否则,使用模型的初始参数进行初始化。
-
转置输入数据:
- 将输入数据进行转置,以适应后续的计算流程。
-
分配输出张量:
- 分配一个空的输出张量
XQW_batch
。
- 分配一个空的输出张量
-
扫描微批处理:
- 使用
scan
函数对微批处理进行扫描计算。 - 更新
batch_params_dict
和XQW_batch
。
- 使用
-
更新缓存参数:
- 如果缓存参数存在,则更新缓存参数。
-
转置并重新形状化输出:
对输出进行转置并重新形状化。
Block
类:
-
Block
类是模型的一个组成部分,通常包含多个子层(如TTT层、MLP层等)。__init__
方法初始化了各种层(如TTT层、MLP层、卷积层、归一化层等),根据配置决定是否使用卷积层。forward
方法实现了前向传播过程,包含了TTT层、归一化和MLP层的计算,并应用了残差连接。
TTTPreTrainedModel
类:
TTTPreTrainedModel
类是预训练模型的基类。- 包含模型的配置类和基础模型前缀信息,支持梯度检查点功能。
_init_weights
方法初始化模型的权重,主要针对线性层和嵌入层。
class Block(nn.Module):
def __init__(self, config: TTTConfig, layer_idx: int):
"""初始化Block层
Args:
config (TTTConfig): 模型配置对象。
layer_idx (int): 当前层的索引。
"""
super().__init__()
self.hidden_size = config.hidden_size # 隐藏层大小
self.pre_conv = config.pre_conv # 是否使用卷积层
# 根据配置选择TTT层类型
if config.ttt_layer_type == "linear":
ttt_layer = TTTLinear
elif config.ttt_layer_type == "mlp":
ttt_layer = TTTMLP
else:
raise ValueError(f"Invalid ttt_layer_type: {config.ttt_layer_type}")
# 序列建模块
self.seq_modeling_block = ttt_layer(config=config, layer_idx=layer_idx)
# MLP层
self.mlp = SwiGluMLP(config)
# 如果启用卷积,则初始化卷积层
if self.pre_conv:
self.conv = Conv(config, layer_idx)
# 归一化层
self.seq_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.ffn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.layer_idx = layer_idx
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
cache_params: Optional[TTTCache] = None,
) -> torch.Tensor:
"""前向传播函数
Args:
hidden_states (torch.Tensor): 输入隐藏状态张量。
attention_mask (Optional[torch.Tensor]): 注意力掩码。
position_ids (Optional[torch.LongTensor]): 位置ID。
cache_params (Optional[TTTCache]): 缓存参数。
Returns:
torch.Tensor: 输出隐藏状态张量。
"""
if self.pre_conv:
residual = hidden_states # 保存残差
hidden_states = self.conv(hidden_states, cache_params=cache_params) # 应用卷积层
hidden_states = residual + hidden_states # 残差连接
residual = hidden_states # 保存残差
hidden_states = self.seq_norm(hidden_states) # 应用归一化
# TTT层
hidden_states = self.seq_modeling_block(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
cache_params=cache_params,
)
hidden_states = residual + hidden_states # 残差连接
# 前馈网络
residual = hidden_states # 保存残差
hidden_states = self.ffn_norm(hidden_states) # 应用归一化
hidden_states = self.mlp(hidden_states) # 应用MLP层
hidden_states = residual + hidden_states # 残差连接
return hidden_states
class TTTPreTrainedModel(PreTrainedModel):
config_class = TTTConfig # 配置类
base_model_prefix = "model" # 基础模型前缀
supports_gradient_checkpointing = True # 支持梯度检查点
_no_split_modules = ["Block"] # 不拆分的模块列表
def _init_weights(self, module):
"""初始化权重
Args:
module (nn.Module): 需要初始化的模块。
"""
std = self.config.initializer_range # 初始化范围
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=std) # 初始化线性层权重
if module.bias is not None:
module.bias.data.zero_() # 初始化线性层偏置
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std) # 初始化嵌入层权重
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_() # 初始化嵌入层填充索引的权重
class TTTOutput和class TTTCausalLMOutput
@dataclass
class TTTOutput(ModelOutput):
"""
TTT模型输出类。
Args:
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
模型最后一层输出的隐藏状态序列。
cache_params (`TTTCache`):
模型在最后一个时间步的状态。可以在下一个 `input_ids` 的前向方法中使用,以避免提供旧的 `input_ids`。
hidden_states (`Tuple[torch.FloatTensor]`, optional):
模型各层的隐藏状态序列(可选)。
"""
last_hidden_state: Optional[torch.FloatTensor] = None # 最后一层隐藏状态
cache_params: Optional[TTTCache] = None # 缓存参数
hidden_states: Optional[Tuple[torch.FloatTensor]] = None # 各层隐藏状态序列(可选)
@dataclass
class TTTCausalLMOutput(ModelOutput):
"""
因果语言模型(或自回归)输出的基类。
Args:
loss (`torch.FloatTensor` of shape `(1,)`, optional):
语言模型损失(用于下一个token预测),仅在提供 `labels` 时返回。
logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
语言建模头的预测分数(SoftMax之前每个词汇token的分数)。
cache_params (`TTTCache`):
模型在最后一个时间步的状态。可以在下一个 `input_ids` 的前向方法中使用,以避免提供旧的 `input_ids`。
hidden_states (`Tuple[torch.FloatTensor]`, optional):
模型各层的隐藏状态序列(可选)。
"""
loss: Optional[torch.FloatTensor] = None # 损失
logits: Optional[torch.FloatTensor] = None # 预测分数
cache_params: Optional[TTTCache] = None # 缓存参数
hidden_states: Optional[Tuple[torch.FloatTensor]] = None # 各层隐藏状态序列(可选)
class TTTModel
- 基于
TTTPreTrainedModel
的解码器模型类,包含多个Block
层。
forward:
- 校验输入参数确保不能同时指定
input_ids
和inputs_embeds
。 - 检查梯度检查点与缓存的兼容性。
- 根据
input_ids
获取输入嵌入,如果未提供inputs_embeds
。 - 初始化
cache_params
,如果启用了缓存。 - 计算位置ID的偏移量。
- 创建注意力掩码,如果未提供。
- 通过多个解码层进行计算,每层包括梯度检查点。
- 规范化最后的隐藏状态。
- 返回结果,支持返回字典格式或元组格式。
-
class TTTModel(TTTPreTrainedModel): """ 包含 *config.num_hidden_layers* 层的解码器。每一层都是一个 [`Block`] Args: config: TTTConfig """ def __init__(self, config: TTTConfig): super().__init__(config) self.padding_idx = config.pad_token_id # 填充索引 self.vocab_size = config.vocab_size # 词汇表大小 # 嵌入层 self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) # 多层结构 self.layers = nn.ModuleList([Block(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]) # 归一化层 self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.gradient_checkpointing = False # 是否使用梯度检查点 # 初始化权重并应用最终处理 self.post_init() def get_input_embeddings(self): """获取输入嵌入层""" return self.embed_tokens def set_input_embeddings(self, value): """设置输入嵌入层""" self.embed_tokens = value def forward( self, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, cache_params: Optional[TTTCache] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, use_cache: Optional[bool] = None, ) -> Union[Tuple, ModelOutput]: """前向传播函数 Args: input_ids (torch.LongTensor, optional): 输入ID。 attention_mask (Optional[torch.Tensor], optional): 注意力掩码。 position_ids (Optional[torch.LongTensor], optional): 位置ID。 inputs_embeds (Optional[torch.FloatTensor], optional): 输入嵌入。 cache_params (Optional[TTTCache], optional): 缓存参数。 output_hidden_states (Optional[bool], optional): 是否输出隐藏状态。 return_dict (Optional[bool], optional): 是否返回字典。 use_cache (Optional[bool], optional): 是否使用缓存。 Returns: Union[Tuple, ModelOutput]: 输出结果。 """ output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = return_dict if return_dict is not None else self.config.use_return_dict # 确保不能同时指定 input_ids 和 inputs_embeds,且必须指定其中一个 if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError( "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" ) # 检查梯度检查点与缓存的兼容性 if self.gradient_checkpointing and self.training and use_cache: logger.warning_once( "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." ) use_cache = False # 如果未提供 inputs_embeds,则根据 input_ids 获取输入嵌入 if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) # 如果未提供 cache_params 且使用缓存,则初始化 cache_params if cache_params is None and use_cache: cache_params = TTTCache(self, inputs_embeds.size(0)) # 计算位置ID的偏移量 seqlen_offset = 0 if cache_params is not None: seqlen_offset = cache_params.seqlen_offset position_ids = torch.arange( seqlen_offset, seqlen_offset + inputs_embeds.shape[1], dtype=torch.long, device=inputs_embeds.device, ).unsqueeze(0) hidden_states = inputs_embeds # 如果未提供 attention_mask,则创建一个全为1的掩码 if attention_mask is None: attention_mask = torch.ones_like(input_ids) # 解码层 all_hidden_states = () if output_hidden_states else None for decoder_layer in self.layers: if self.gradient_checkpointing and self.training: hidden_states = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, attention_mask, position_ids, cache_params, ) else: hidden_states = decoder_layer( hidden_states, attention_mask=attention_mask, position_ids=position_ids, cache_params=cache_params, ) if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) if use_cache: cache_params.seqlen_offset += inputs_embeds.shape[1] hidden_states = self.norm(hidden_states) # 添加最后一个解码层的隐藏状态 if output_hidden_states: all_hidden_states += (hidden_states,) if not return_dict: return tuple(v for v in [hidden_states, cache_params, all_hidden_states] if v is not None) return TTTOutput( last_hidden_state=hidden_states, cache_params=cache_params if use_cache else None, hidden_states=all_hidden_states, )
class TTTForCausalLM
class TTTForCausalLM(TTTPreTrainedModel):
_tied_weights_keys = ["lm_head.weight"] # 绑定权重键
def __init__(self, config):
super().__init__(config)
self.model = TTTModel(config) # 初始化解码器模型
self.vocab_size = config.vocab_size # 词汇表大小
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) # 语言模型头
# 初始化权重并应用最终处理
self.post_init()
def get_input_embeddings(self):
"""获取输入嵌入层"""
return self.model.embed_tokens
def set_input_embeddings(self, value):
"""设置输入嵌入层"""
self.model.embed_tokens = value
def get_output_embeddings(self):
"""获取输出嵌入层"""
return self.lm_head
def set_output_embeddings(self, new_embeddings):
"""设置输出嵌入层"""
self.lm_head = new_embeddings
def set_decoder(self, decoder):
"""设置解码器"""
self.model = decoder
def get_decoder(self):
"""获取解码器"""
return self.model
def _update_model_kwargs_for_generation(
self, outputs: ModelOutput, model_kwargs: Dict[str, Any], **kwargs
) -> Dict[str, Any]:
"""更新生成过程中的模型参数
Args:
outputs (ModelOutput): 模型输出。
model_kwargs (Dict[str, Any]): 模型参数。
Returns:
Dict[str, Any]: 更新后的模型参数。
"""
model_kwargs["cache_params"] = outputs.get("cache_params", None)
# 更新注意力掩码
if "attention_mask" in model_kwargs:
attention_mask = model_kwargs["attention_mask"]
model_kwargs["attention_mask"] = torch.cat(
[attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))],
dim=-1,
)
return model_kwargs
def prepare_inputs_for_generation(
self,
input_ids,
attention_mask=None,
cache_params: Optional[TTTCache] = None,
inputs_embeds=None,
**kwargs,
):
"""准备生成过程中的输入
Args:
input_ids (torch.LongTensor): 输入ID。
attention_mask (Optional[torch.Tensor], optional): 注意力掩码。
cache_params (Optional[TTTCache], optional): 缓存参数。
inputs_embeds (Optional[torch.FloatTensor], optional): 输入嵌入。
Returns:
Dict[str, Any]: 准备好的输入。
"""
# 仅保留最后一个token的input_ids,如果状态传递
if cache_params is not None:
input_ids = input_ids[:, -1].unsqueeze(-1)
attention_mask = attention_mask[:, -1].unsqueeze(-1) if attention_mask is not None else None
if inputs_embeds is not None and cache_params is None:
model_inputs = {"inputs_embeds": inputs_embeds}
else:
model_inputs = {"input_ids": input_ids}
model_inputs.update(
{
"cache_params": cache_params,
"use_cache": kwargs.get("use_cache"),
"attention_mask": attention_mask,
}
)
return model_inputs
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
cache_params: Optional[TTTCache] = None,
labels: Optional[torch.LongTensor] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
use_cache: Optional[bool] = None,
*,
output_attentions: Optional[bool] = None,
) -> Union[Tuple, TTTCausalLMOutput]:
"""前向传播函数
Args:
input_ids (torch.LongTensor, optional): 输入ID。
attention_mask (Optional[torch.Tensor], optional): 注意力掩码。
position_ids (Optional[torch.LongTensor], optional): 位置ID。
inputs_embeds (Optional[torch.FloatTensor], optional): 输入嵌入。
cache_params (Optional[TTTCache], optional): 缓存参数。
labels (Optional[torch.LongTensor], optional): 标签。
output_hidden_states (Optional[bool], optional): 是否输出隐藏状态。
return_dict (Optional[bool], optional): 是否返回字典。
use_cache (Optional[bool], optional): 是否使用缓存。
output_attentions (Optional[bool], optional): 是否输出注意力。
Returns:
Union[Tuple, TTTCausalLMOutput]: 输出结果。
"""
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
assert not output_attentions, "output_attentions is not available in TTTForCausalLM"
# 解码器输出包括 (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
cache_params=cache_params,
inputs_embeds=inputs_embeds,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
use_cache=use_cache,
)
hidden_states = outputs[0]
if self.config.pretraining_tp > 1:
lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
logits = torch.cat(logits, dim=-1)
else:
logits = self.lm_head(hidden_states)
logits = logits.float()
loss = None
if labels is not None:
# 移位以便 tokens < n 预测 n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# 展平 tokens
loss_fct = CrossEntropyLoss()
shift_logits = shift_logits.view(-1, self.config.vocab_size)
shift_labels = shift_labels.view(-1)
# 启用模型并行
shift_labels = shift_labels.to(shift_logits.device)
loss = loss_fct(shift_logits, shift_labels)
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return TTTCausalLMOutput(
loss=loss,
logits=logits,
cache_params=outputs.cache_params,
hidden_states=outputs.hidden_states,
)
用于因果语言建模(Causal Language Modeling)的类,它继承自 TTTPreTrainedModel
。此类的主要作用是将一个预训练的 TTT 模型(TTTModel
)封装成一个用于语言生成的模型
-
初始化 (
__init__
): 初始化模型,包括解码器模型TTTModel
和语言模型头lm_head
,后者是一个线性层,用于将隐藏状态映射到词汇表大小的 logits 上。 -
输入嵌入 (
get_input_embeddings
和set_input_embeddings
): 提供获取和设置模型输入嵌入层的方法。 -
输出嵌入 (
get_output_embeddings
和set_output_embeddings
): 提供获取和设置模型输出嵌入层的方法,这通常用于调整模型的输出以适应不同的任务。 -
设置和获取解码器 (
set_decoder
和get_decoder
): 允许用户设置自定义的解码器,并能够获取当前使用的解码器。 -
更新生成参数 (
_update_model_kwargs_for_generation
): 在生成文本时,更新模型参数,如缓存参数和注意力掩码。 -
准备生成输入 (
prepare_inputs_for_generation
): 准备用于文本生成的输入,包括处理输入 ID、注意力掩码和缓存参数。 -
前向传播 (
forward
): 定义了模型的前向传播逻辑,包括处理输入 ID、注意力掩码、位置 ID、输入嵌入、缓存参数等,并计算损失和 logits。如果提供标签labels
,则计算交叉熵损失。最后,根据return_dict
参数决定返回格式。 -
因果语言模型输出 (
TTTCausalLMOutput
): 定义了因果语言模型的输出格式,包括损失、logits、缓存参数和隐藏状态。
TTTForCausalLM
类通过这些方法和属性,实现了一个完整的因果语言模型,能够用于文本生成任务,如文本续写、问答等。它利用了 TTT 模型的高效特性,并添加了语言模型头来进行词汇级别的预测