深入解析Instruction Tuning 的 Loss 计算(基于 Open-Instruct 框架)

Instruction tuning中loss怎么算的

# 训练函数
def train_red_model(model, train_dataloader, epochs=3, lr=6e-3):
    optimizer = AdamW(model.parameters(), lr=lr)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    
    for epoch in range(epochs):
        model.train()
        total_loss = 0
        for batch in train_dataloader:
            input_ids = batch["input_ids"].squeeze(1).to(device)  # 移除多余维度
            attention_mask = batch["attention_mask"].squeeze(1).to(device)
            
            optimizer.zero_grad()
            outputs = model(input_ids, attention_mask=attention_mask, labels=input_ids)  # 自回归任务
            loss = outputs.loss
            loss.backward()
            optimizer.step()
            
            total_loss += loss.item()
        
        avg_loss = total_loss / len(train_dataloader)
        print(f"Epoch {epoch + 1}/{epochs}, Average Loss: {avg_loss:.4f}")这种instruction tuning的loss是怎么算的,详细解释一下

在代码中,train_red_model 函数使用了一个自回归语言建模任务的损失计算方式,具体体现在 outputs = model(input_ids, attention_mask=attention_mask, labels=input_ids)loss = outputs.loss 这两行。这是一种常见的 instruction tuning 损失计算方法,适用于像Qwen2.5-3B这样的自回归语言模型(causal language model)。下面将详细解释这种损失的计算原理、背后的逻辑,以及它在instruction tuning中的意义。


1. 自回归语言建模的基本概念

自回归语言模型(例如Qwen2.5-3B、GPT系列)是通过预测序列中下一个token来训练的。给定一个输入序列 ( x = [ x 1 , x 2 , … , x n ] x = [x_1, x_2, \dots, x_n] x=[x1,x2,,xn] ),模型的目标是学习条件概率分布:
P ( x t ∣ x 1 , x 2 , … , x t − 1 ) P(x_t | x_1, x_2, \dots, x_{t-1}) P(xtx1,x2,,xt1)
其中 ( x t x_t xt ) 是第 ( t t t ) 个token,基于前面的上下文 ( x 1 , x 2 , … , x t − 1 x_1, x_2, \dots, x_{t-1} x1,x2,,xt1 ) 进行预测。

训练时,模型通过最大化整个序列的对数似然来优化参数,即最小化负对数似然损失(Negative Log-Likelihood, NLL):
L = − 1 n ∑ t = 1 n log ⁡ P ( x t ∣ x 1 , x 2 , … , x t − 1 ) \mathcal{L} = -\frac{1}{n} \sum_{t=1}^{n} \log P(x_t | x_1, x_2, \dots, x_{t-1}) L=n1t=1nlogP(xtx1,x2,,xt1)


2. Instruction Tuning中的数据与任务

在instruction tuning中,数据通常是对话或指令-回复对的形式。例如,你的 allenai/tulu-3-sft-mixture 数据集中,一个样本可能如下:

  • 输入"Instruction: Write a function to find palindromes\nResponse: def find_palindrome_titles..."

为了适配自回归语言模型,我们将输入和输出拼接为一个完整序列,并要求模型预测整个序列的token。这种方式假设模型不仅要理解指令,还要生成对应的回复。训练时,输入和标签是同一个序列labels=input_ids),但通过掩码机制确保模型只预测未来的token。


3. 代码中的损失计算细节

在代码中:

outputs = model(input_ids, attention_mask=attention_mask, labels=input_ids)
loss = outputs.loss

model 是基于 AutoModelForCausalLM 的Qwen2.5-3B模型,调用其 forward 方法时,labels=input_ids 表示将输入序列同时作为预测目标。下面是损失计算的逐步拆解:

输入与标签
  • input_ids:形状为 [batch_size, seq_len],表示token化的输入序列。例如:
    [Instruction, :, Write, a, function, ..., Response, :, def, find, ...]
    
  • attention_mask:形状为 [batch_size, seq_len],标记有效token(1)和填充token(0)。
  • labels:设置为 input_ids,即模型的目标是重建整个输入序列。
自回归掩码

自回归模型使用因果掩码(causal mask),确保在预测第 ( t t t ) 个token时,只依赖前 ( t − 1 t-1 t1 ) 个token。这种掩码由Transformer的注意力机制自动实现:

  • 对于位置 ( t t t ),注意力只能看到 ( [ 1 , 2 , … , t − 1 ] [1, 2, \dots, t-1] [1,2,,t1] ) 的token,后续token被屏蔽。
  • 例如,给定序列 [A, B, C]
    • 预测 A 时无输入(仅初始化)。
    • 预测 B 时输入 A
    • 预测 C 时输入 A, B
前向传播与Logits

模型的 forward 方法返回一个对象(通常是 CausalLMOutputWithPast),包含:

  • logits:形状为 [batch_size, seq_len, vocab_size],表示每个位置的token预测概率分布。
  • loss:如果提供了 labels,会自动计算损失。

对于每个位置 ( t t t ):

  • 输入:( x 1 , x 2 , … , x t − 1 x_1, x_2, \dots, x_{t-1} x1,x2,,xt1 )
  • 输出:( logits t ∈ R vocab_size \text{logits}_t \in \mathbb{R}^{\text{vocab\_size}} logitstRvocab_size )
  • 目标:( x t x_t xt )(来自 labels
交叉熵损失

outputs.loss 是由Hugging Face内置的交叉熵损失函数计算的:
L = − 1 N ∑ i = 1 B ∑ t = 1 S 1 mask i , t = 1 log ⁡ P ( x i , t ∣ x i , 1 , … , x i , t − 1 ) \mathcal{L} = -\frac{1}{N} \sum_{i=1}^{B} \sum_{t=1}^{S} \mathbb{1}_{\text{mask}_{i,t}=1} \log P(x_{i,t} | x_{i,1}, \dots, x_{i,t-1}) L=N1i=1Bt=1S1maski,t=1logP(xi,txi,1,,xi,t1)
其中:

  • ( B B B ) 是批大小(batch_size)。
  • ( S S S ) 是序列长度(seq_len)。
  • ( 1 mask i , t = 1 \mathbb{1}_{\text{mask}_{i,t}=1} 1maski,t=1 ) 表示只对有效token(由 attention_mask 指定)计算损失,忽略padding部分。
  • ( P ( x i , t ∣ x i , 1 , … , x i , t − 1 ) P(x_{i,t} | x_{i,1}, \dots, x_{i,t-1}) P(xi,txi,1,,xi,t1) ) 是模型预测的概率,通过softmax从 logits 得到:
    P ( x i , t = k ) = exp ⁡ ( logits i , t , k ) ∑ j = 1 vocab_size exp ⁡ ( logits i , t , j ) P(x_{i,t} = k) = \frac{\exp(\text{logits}_{i,t,k})}{\sum_{j=1}^{\text{vocab\_size}} \exp(\text{logits}_{i,t,j})} P(xi,t=k)=j=1vocab_sizeexp(logitsi,t,j)exp(logitsi,t,k)
    其中 ( k k k ) 是目标token的索引。
损失的实际计算
  1. Shift操作
    • input_ids 被用作输入,labels 被用作目标,但 labels 会被左移一位(shifted),因为我们预测的是下一个token。
    • 例如,输入 [A, B, C]labels 会被处理为 [B, C, <ignore>],而 logits 只计算到倒数第二个位置。
  2. 忽略填充
    • attention_mask 确保损失只计算有效token部分。例如,若序列中有padding([A, B, <pad>]),则只对 [A, B] 计算损失。
  3. 平均
    • 最终损失是对所有有效token的平均负对数似然。

4. Instruction Tuning中的意义

在instruction tuning中,这种损失计算方式有以下特点:

自监督性质
  • 通过将 labels=input_ids,模型被训练为“自重建”整个序列,包括指令和回复。这本质上是一个监督形式的自回归任务。
  • 对于例子:
    Instruction: Write a function to find palindromes\nResponse: def find_palindrome_titles...
    
    模型会逐步预测每个token,从 Instruction 到最后的 ...,确保生成的回复符合指令。
条件生成
  • 由于指令在前,回复在后,模型学习到 ( P ( Response ∣ Instruction ) P(\text{Response} | \text{Instruction}) P(ResponseInstruction) )。训练时,指令部分的token被用来条件化回复部分的生成。
  • 例如,模型在看到 Instruction: Write a function 后,预测 Response: def,逐步生成完整代码。

5. 与传统分类任务的区别

相比分类任务(如RoBERTa在GLUE上的交叉熵损失),instruction tuning的损失有以下不同:

  • 目标:分类任务预测单一类别(如正面/负面),而这里预测整个序列。
  • 维度:分类损失基于 [batch_size, num_labels] 的logits,而自回归损失基于 [batch_size, seq_len, vocab_size]
  • 掩码:自回归任务使用因果掩码和注意力掩码,确保单向预测和忽略填充。

6. 代码中的具体实现细节

squeeze(1)
  • input_idsattention_mask 的初始形状可能是 [batch_size, 1, seq_len](因DataLoader返回的张量格式),squeeze(1) 移除多余的维度,变为 [batch_size, seq_len]
outputs.loss
  • Hugging Face的 AutoModelForCausalLM 自动处理了shift、掩码和损失计算。你无需手动实现,只需确保 labelsinput_ids 一致。
损失值含义
  • loss.item() 是标量,表示当前batch的平均负对数似然。越低表示模型预测越接近目标序列。

7. 可能的优化与调试

  • 检查数据:确保 input_idsattention_mask 正确对齐,避免padding干扰损失。
  • 调整学习率:6e-3可能偏高,若损失不收敛,可尝试1e-4或5e-5。
  • 添加验证:定期计算困惑度(exp(loss)),评估模型在验证集上的生成能力。

总结

train_red_model 函数中,损失是基于自回归语言建模的交叉熵损失,计算模型预测下一个token的准确性。结合instruction tuning,模型通过这种方式学习生成与指令匹配的回复。这种方法在对话、代码生成等任务中非常强大,体现了指令微调的核心思想。


深入解析 Open-Instruct 中 Instruction Tuning 的 Loss 计算

介绍基于 open-instruct 框架的SFT脚本中instruction tuning的损失(loss)是如何计算的。会逐步拆解代码中的关键部分,解释其背后的原理,并结合上下文分析其在指令微调中的意义。这篇博客面向熟悉深度学习和自然语言处理的读者。

随着大语言模型(LLM)的广泛应用,指令微调(Instruction Tuning)成为提升模型任务适配能力的重要方法。在AllenAI开发的 open-instruct 框架中,针对自回归语言模型(如LLaMA、Qwen等)的监督微调(Supervised Fine-Tuning, SFT)实现了一个高效的训练流程。本文将聚焦于该框架中损失(loss)的计算逻辑,结合代码逐步分析其实现细节,并探讨其在instruction tuning中的作用。


数据准备与输入格式

open-instruct 的SFT脚本中,训练数据通常来源于对话格式的数据集(如 allenai/tulu-3-sft-mixture),每个样本包含一个 messages 字段,表示多轮对话。例如:

{
  "messages": [
    {"role": "user", "content": "Write a function to find palindromes"},
    {"role": "assistant", "content": "def find_palindrome_titles..."}
  ]
}

数据预处理由 encode_sft_example 函数完成,其核心是将对话转换为适合自回归训练的输入-标签对:

input_ids = tokenizer.apply_chat_template(
    conversation=messages,
    tokenize=True,
    return_tensors="pt",
    padding=False,
    truncation=True,
    max_length=max_seq_length,
    add_generation_prompt=False,
)
labels = input_ids.clone()
  • input_ids:通过 apply_chat_template 将对话序列化为token ID序列。例如,结果可能是 [Instruction, :, Write, ..., Response, :, def, ...]
  • labels:初始时与 input_ids 相同,但后续会通过掩码调整。

关键点在于,labels 中非助手(assistant)部分的token会被设置为 -100,以排除其对损失的贡献:

for message_idx, message in enumerate(messages):
    if message["role"] != "assistant":
        # 计算非助手部分的起始和结束索引
        message_start_idx = ...  # 前文token数
        message_end_idx = ...    # 当前消息结束位置
        labels[:, message_start_idx:message_end_idx] = -100

这种掩码机制确保模型只对助手的回复部分计算损失,而指令部分仅作为上下文,不参与梯度更新。


训练循环中的损失计算

训练循环位于 main 函数中,核心逻辑如下:

for epoch in range(starting_epoch, args.num_train_epochs):
    for step, batch in enumerate(active_dataloader):
        with accelerator.accumulate(model):
            outputs = model(**batch, use_cache=False)
            if args.reduce_loss == "mean":
                loss = outputs.loss
            else:  # "sum"
                logits = outputs.logits
                labels = batch["labels"]
                shift_logits = logits[..., :-1, :].contiguous()
                shift_labels = labels[..., 1:].contiguous()
                loss_fct = torch.nn.CrossEntropyLoss(reduction="sum")
                shift_logits = shift_logits.view(-1, embedding_size)
                shift_labels = shift_labels.view(-1)
                loss = loss_fct(shift_logits, shift_labels)
            total_loss += loss.detach().float()
            accelerator.backward(loss)
输入与输出
  • batch:包含 input_idslabelsattention_mask,形状为 [batch_size, seq_len]
  • outputs:调用 model(**batch) 返回一个 CausalLMOutputWithPast 对象,包含:
    • logits:形状 [batch_size, seq_len, vocab_size],表示每个位置的token预测概率。
    • loss:如果提供了 labels,则自动计算的损失。
两种损失计算方式

脚本支持两种损失缩减方式,由 args.reduce_loss 参数控制(默认 "mean"):

  1. "mean" 模式

    • 直接使用 outputs.loss,这是Hugging Face transformers 内置的自回归损失:
      L mean = − 1 N ∑ i = 1 B ∑ t = 1 S 1 labels i , t ≠ − 100 log ⁡ P ( x i , t ∣ x i , 1 , … , x i , t − 1 ) \mathcal{L}_{\text{mean}} = -\frac{1}{N} \sum_{i=1}^{B} \sum_{t=1}^{S} \mathbb{1}_{\text{labels}_{i,t} \neq -100} \log P(x_{i,t} | x_{i,1}, \dots, x_{i,t-1}) Lmean=N1i=1Bt=1S1labelsi,t=100logP(xi,txi,1,,xi,t1)
    • 其中:
      • ( B B B ):批大小。
      • ( S S S ):序列长度。
      • ( N N N ):有效token数(labels != -100 的数量)。
      • ( P ( x i , t ∣ . . . ) P(x_{i,t} | ...) P(xi,t∣...) ):通过 logits 的softmax计算。
    • 特点
      • 内置损失自动处理了因果掩码(只预测下一个token)和标签掩码(忽略 -100)。
      • 对每个有效token的损失取平均,适合均衡样本权重。
  2. "sum" 模式

    • 手动计算交叉熵损失, summing 而非 averaging:
      shift_logits = logits[..., :-1, :].contiguous()
      shift_labels = labels[..., 1:].contiguous()
      loss_fct = torch.nn.CrossEntropyLoss(reduction="sum")
      loss = loss_fct(shift_logits.view(-1, embedding_size), shift_labels.view(-1))
      
    • 步骤
      1. Shift操作logits 去掉最后一位(无对应标签),labels 去掉第一位(无前文预测),对齐预测与目标。
      2. 展平:将张量变为 [batch_size * (seq_len-1), vocab_size][batch_size * (seq_len-1)]
      3. 交叉熵
        L sum = − ∑ i = 1 B ∑ t = 1 S − 1 1 labels i , t ≠ − 100 log ⁡ P ( x i , t ∣ x i , 1 , … , x i , t − 1 ) \mathcal{L}_{\text{sum}} = -\sum_{i=1}^{B} \sum_{t=1}^{S-1} \mathbb{1}_{\text{labels}_{i,t} \neq -100} \log P(x_{i,t} | x_{i,1}, \dots, x_{i,t-1}) Lsum=i=1Bt=1S11labelsi,t=100logP(xi,txi,1,,xi,t1)
    • 特点
      • 对所有有效token的损失求和,而不是平均。
      • 在高梯度累积(gradient accumulation)场景下,能更平等地对待每个token,避免样本长度的影响(详见后文讨论)。
梯度计算
  • accelerator.backward(loss):根据损失计算梯度,结合 accumulate 实现梯度累积。
  • 仅对 labels != -100 的token(如助手回复)产生梯度贡献。

损失的数学原理

自回归语言模型的目标是最大化序列的对数似然:
L = − ∑ t = 1 S log ⁡ P ( x t ∣ x 1 , … , x t − 1 ) \mathcal{L} = -\sum_{t=1}^{S} \log P(x_t | x_1, \dots, x_{t-1}) L=t=1SlogP(xtx1,,xt1)
在instruction tuning中:

  • 输入序列包含指令和回复。
  • labels 中的掩码(-100)限制损失仅计算回复部分。
  • 例如,对于序列 [Instruction, :, Write, Response, :, def]
    • input_ids[1, 2, 3, 4, 5, 6]
    • labels[-100, -100, -100, 4, 5, 6]
    • 损失仅基于 [4, 5, 6] 的预测。

"mean" vs "sum" 的差异

  • "mean":归一化到有效token数,适合样本长度差异较大时。
  • "sum":保留总损失规模,可能在长序列中放大影响,论文中提到这在对话任务(如AlpacaEval)中可提升5个百分点。

Instruction Tuning 中的意义
掩码的作用

通过将非助手部分的 labels 设置为 -100,模型专注于学习生成回复,而非重复指令。这种方式模拟了对话场景下的条件生成:
P ( Response ∣ Instruction ) P(\text{Response} | \text{Instruction}) P(ResponseInstruction)

  • 指令作为上下文,提供生成的基础。
  • 损失仅优化回复的生成质量。
"sum" 模式的优势

代码注释提到,使用 "sum" 可改进聊天模型性能(见 https://github.com/huggingface/transformers/issues/24725)。原因在于:

  • 在高梯度累积下,"mean" 会平均每个样本的损失,导致短样本的影响被放大。
  • "sum" 按token数加权,确保长序列的训练信号不被稀释。

代码中的关键实现细节
  1. 数据加载

    • DataCollatorForSeq2Seq 动态填充批次,确保 input_idslabels 对齐。
    • attention_mask 标记有效token,防止padding干扰。
  2. Shift操作

    • 自回归模型预测下一个token,因此 logitslabels 需要错位对齐。
  3. 加速器支持

    • accelerator 管理分布式训练和梯度累积,total_loss 在主进程上聚合。
  4. 日志与监控

    • logging_steps 步记录平均损失和token处理速度(TPS)。

调试与优化建议
  • 检查掩码:打印 labelsinput_ids,确保非助手部分正确置为 -100
  • 损失异常:若损失过高,可能是学习率(默认 2e-5)需调整,或数据预处理有误。
  • 对比实验:尝试 "mean""sum",观察在你的数据集(如 tulu-3-sft-mixture)上的性能差异。

结语

open-instruct 的SFT脚本通过自回归损失实现了instruction tuning的核心目标:让模型根据指令生成高质量回复。其损失计算结合了掩码机制(聚焦助手部分)和灵活的缩减方式("mean""sum"),既高效又适应多种场景。理解这些细节不仅有助于复现训练流程,还能为自定义微调策略提供启发。未来,你可以尝试调整掩码逻辑或损失权重,进一步优化模型在特定任务上的表现。

logits和labels移位解释

在自回归语言模型的训练中,logits 去掉最后一位和 labels 去掉第一位的“shift操作”是损失计算的核心步骤,用来对齐模型的预测与目标。这里的操作看似简单,但背后有深刻的逻辑,与自回归任务的预测机制紧密相关。下面我将详细解释为什么 logits 需要去掉最后一位,以及这一操作的目的和意义。


自回归语言建模的基本原理

自回归语言模型(如Qwen2.5-3B、LLaMA等)通过逐步预测序列中的下一个token来生成文本。给定一个输入序列 ( x = [ x 1 , x 2 , … , x n ] x = [x_1, x_2, \dots, x_n] x=[x1,x2,,xn] ),模型的目标是学习条件概率分布:
P ( x t ∣ x 1 , x 2 , … , x t − 1 ) P(x_t | x_1, x_2, \dots, x_{t-1}) P(xtx1,x2,,xt1)
其中:

  • ( x t x_t xt ) 是第 ( t t t ) 个token。
  • ( x 1 , x 2 , … , x t − 1 x_1, x_2, \dots, x_{t-1} x1,x2,,xt1 ) 是前文上下文。

在训练时,模型一次性处理整个序列,输出每个位置的预测概率(logits),然后与目标序列(labels)比较,计算损失。但由于自回归的因果性质,模型在每个位置 ( t t t ) 的预测是基于前 ( t − 1 t-1 t1 ) 个token的,因此需要对齐输入和输出的维度。


输入与输出的维度

假设输入序列为 [A, B, C](长度为3):

  • input_ids[A, B, C],形状 [batch_size, seq_len],这里 seq_len = 3
  • labels:通常与 input_ids 相同,[A, B, C],因为自回归任务的目标是重建输入序列。

模型处理后:

  • logits:形状 [batch_size, seq_len, vocab_size],即 [3, vocab_size],为每个位置生成一个词汇表大小的预测分布。

具体来说:

  • 位置 0(输入 [A]):预测下一个token(应为 B),输出 logits_0
  • 位置 1(输入 [A, B]):预测下一个token(应为 C),输出 logits_1
  • 位置 2(输入 [A, B, C]):预测下一个token(应为序列外的未知token),输出 logits_2

为什么需要 Shift 操作?

问题在于:logits 的最后一个位置没有对应的真实目标。让我们逐步分析:

  1. logits 的含义

    • logits 表示模型在每个位置对下一个token的预测。
    • 对于长度为 ( n n n ) 的序列,logits 有 ( n n n ) 个位置,分别对应:
      • logits[0]:预测 ( x 2 x_2 x2 )(基于 ( x 1 x_1 x1 ))。
      • logits[1]:预测 ( x 3 x_3 x3 )(基于 ( x 1 , x 2 x_1, x_2 x1,x2 ))。
      • logits[n-1]:预测 ( x n x_n xn )(基于 ( x 1 , … , x n − 1 x_1, \dots, x_{n-1} x1,,xn1 ))。
      • logits[n]:预测序列外的下一个token(基于 ( x 1 , … , x n x_1, \dots, x_n x1,,xn ))。
  2. labels 的含义

    • labels 是我们希望模型预测的目标序列,通常与 input_ids 相同。
    • 对于 [A, B, C]labels = [A, B, C],但在自回归任务中,我们关心的是“下一个token”:
      • 基于 [A] 预测 B
      • 基于 [A, B] 预测 C
      • 基于 [A, B, C] 预测什么?——没有后续token可用。
  3. 维度不对齐

    • logits 有 ( n ) 个位置([logits_0, logits_1, logits_2])。
    • 但有效目标只有 ( n-1 ) 个([B, C]),因为最后一个 logits_2 预测的是序列外的token,而训练数据中没有提供这个目标。

Shift 操作的具体作用

为了解决上述问题,shift_logitsshift_labels 通过“错位”对齐预测和目标:

代码中的实现
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
  • shift_logits = logits[..., :-1, :]

    • 去掉 logits 的最后一位。
    • logits[logits_0, logits_1, logits_2](形状 [batch_size, 3, vocab_size])。
    • shift_logits[logits_0, logits_1](形状 [batch_size, 2, vocab_size])。
    • 原因logits_2 是基于 [A, B, C] 预测的下一个token,但训练数据中没有对应的真实目标(序列到 C 结束),所以丢弃它。
  • shift_labels = labels[..., 1:]

    • 去掉 labels 的第一位。
    • labels[A, B, C](形状 [batch_size, 3])。
    • shift_labels[B, C](形状 [batch_size, 2])。
    • 原因labels 的第一位 A 没有前文可预测(模型从空序列开始无法直接生成 A),而我们关心的是基于前文预测的后续token。
对齐后的结果
  • shift_logits[logits_0, logits_1]
    • logits_0:预测 B
    • logits_1:预测 C
  • shift_labels[B, C]
    • 目标 B:与 logits_0 对齐。
    • 目标 C:与 logits_1 对齐。

现在,shift_logitsshift_labels 的长度一致(均为 ( n − 1 n-1 n1 )),可以直接用于交叉熵损失计算:
L = − ∑ t = 1 n − 1 log ⁡ P ( x t + 1 ∣ x 1 , … , x t ) \mathcal{L} = -\sum_{t=1}^{n-1} \log P(x_{t+1} | x_1, \dots, x_t) L=t=1n1logP(xt+1x1,,xt)


图解说明

[A, B, C] 为例:

输入序列:      [A,    B,    C]
logits:       [logits_0, logits_1, logits_2]
                |预测B     |预测C     |预测下一个未知token
labels:       [A,    B,    C]
shift_logits: [logits_0, logits_1]
shift_labels:       [B,    C]
  • logits_0(基于 [A])预测 B,与 shift_labels[0]B)比较。
  • logits_1(基于 [A, B])预测 C,与 shift_labels[1]C)比较。
  • logits_2(基于 [A, B, C])无对应目标,丢弃。

为什么 logits 去掉最后一位是必要的?

  1. 无目标问题

    • logits[n-1](即 logits_2)预测的是序列外的下一个token,但训练数据只提供到 C,没有后续token的真实值。
    • 如果保留 logits_2,无法为其分配一个合理的 label,会导致损失计算出错或引入噪声。
  2. 因果性

    • 自回归模型的注意力机制使用因果掩码(causal mask),确保位置 ( t t t ) 只依赖 ( [ 1 , … , t − 1 ] [1, \dots, t-1] [1,,t1] )。
    • logits 的最后一个位置仍然会生成预测,只是没有意义(因为没有监督信号)。
  3. 损失计算需求

    • 交叉熵损失要求预测(logits)和目标(labels)的维度匹配。
    • 去掉 logits 最后一位后,shift_logitsshift_labels 的长度均为 ( n − 1 n-1 n1 ),完美对齐。

在 Instruction Tuning 中的意义

在instruction tuning中,序列通常包含指令和回复,例如:

[Instruction, :, Write, Response, :, def, find]
  • input_ids[1, 2, 3, 4, 5, 6, 7]

  • labels(掩码后):[-100, -100, -100, 4, 5, 6, 7](只训练回复部分)。

  • logits[l_0, l_1, l_2, l_3, l_4, l_5, l_6]

  • shift_logits[l_0, l_1, l_2, l_3, l_4, l_5]

  • shift_labels[-100, -100, -100, 4, 5, 6]

  • l_3 预测 4Response),l_4 预测 5:),依此类推。

  • 最后一位 l_6(基于 [1, 2, 3, 4, 5, 6, 7])预测序列外的token,无对应目标,故丢弃。

掩码(-100)进一步确保损失只计算回复部分的预测,而 logits 去掉最后一位则是为了保持维度一致性。


总结

logits 去掉最后一位是为了剔除没有真实目标的预测,使其与 shift_labels 对齐。这种“shift操作”是自回归语言建模的标准做法,确保损失计算只基于有效的前 ( n-1 ) 个预测。结合instruction tuning的掩码机制,这一操作让模型专注于生成回复,同时保持计算的正确性和效率。理解这一点有助于深入掌握自回归训练的细节,也为调试和优化提供了基础。

后记

2025年3月25日14点15分于上海,在grok 3大模型辅助下完成。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值